Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 41 additions & 34 deletions clang-tools-extra/clang-tidy/llvm/UseNewMLIROpBuilderCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ namespace {
using namespace ::clang::ast_matchers;
using namespace ::clang::transformer;

EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
RangeSelector CallArgs) {
EditGenerator rewrite(RangeSelector Call, RangeSelector Builder) {
// This is using an EditGenerator rather than ASTEdit as we want to warn even
// if in macro.
return [Call = std::move(Call), Builder = std::move(Builder),
CallArgs =
std::move(CallArgs)](const MatchFinder::MatchResult &Result)
return [Call = std::move(Call),
Builder = std::move(Builder)](const MatchFinder::MatchResult &Result)
-> Expected<SmallVector<transformer::Edit, 1>> {
Expected<CharSourceRange> CallRange = Call(Result);
if (!CallRange)
Expand All @@ -54,7 +52,7 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
auto NextToken = [&](std::optional<Token> CurrentToken) {
if (!CurrentToken)
return CurrentToken;
if (CurrentToken->getEndLoc() >= CallRange->getEnd())
if (CurrentToken->is(clang::tok::eof))
return std::optional<Token>();
return clang::Lexer::findNextToken(CurrentToken->getLocation(), SM,
LangOpts);
Expand All @@ -68,9 +66,10 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '<' token");
}

std::optional<Token> EndToken = NextToken(LessToken);
for (std::optional<Token> GreaterToken = NextToken(EndToken);
GreaterToken && GreaterToken->getKind() != clang::tok::greater;
std::optional<Token> GreaterToken = NextToken(EndToken);
for (; GreaterToken && GreaterToken->getKind() != clang::tok::greater;
GreaterToken = NextToken(GreaterToken)) {
EndToken = GreaterToken;
}
Expand All @@ -79,12 +78,21 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,
"missing '>' token");
}

std::optional<Token> ArgStart = NextToken(GreaterToken);
if (!ArgStart || ArgStart->getKind() != clang::tok::l_paren) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"missing '(' token");
}
std::optional<Token> Arg = NextToken(ArgStart);
if (!Arg) {
return llvm::make_error<llvm::StringError>(llvm::errc::invalid_argument,
"unexpected end of file");
}
const bool HasArgs = Arg->getKind() != clang::tok::r_paren;

Expected<CharSourceRange> BuilderRange = Builder(Result);
if (!BuilderRange)
return BuilderRange.takeError();
Expected<CharSourceRange> CallArgsRange = CallArgs(Result);
if (!CallArgsRange)
return CallArgsRange.takeError();

// Helper for concatting below.
auto GetText = [&](const CharSourceRange &Range) {
Expand All @@ -93,18 +101,19 @@ EditGenerator rewrite(RangeSelector Call, RangeSelector Builder,

Edit Replace;
Replace.Kind = EditKind::Range;
Replace.Range = *CallRange;
std::string CallArgsStr;
// Only emit args if there are any.
if (auto CallArgsText = GetText(*CallArgsRange).ltrim();
!CallArgsText.rtrim().empty()) {
CallArgsStr = llvm::formatv(", {}", CallArgsText);
Replace.Range.setBegin(CallRange->getBegin());
Replace.Range.setEnd(ArgStart->getEndLoc());
const Expr *BuilderExpr = Result.Nodes.getNodeAs<Expr>("builder");
std::string BuilderText = GetText(*BuilderRange).str();
if (BuilderExpr->getType()->isPointerType()) {
BuilderText = BuilderExpr->isImplicitCXXThis()
? "*this"
: llvm::formatv("*{}", BuilderText).str();
}
Replace.Replacement =
llvm::formatv("{}::create({}{})",
GetText(CharSourceRange::getTokenRange(
LessToken->getEndLoc(), EndToken->getLastLoc())),
GetText(*BuilderRange), CallArgsStr);
const StringRef OpType = GetText(CharSourceRange::getTokenRange(
LessToken->getEndLoc(), EndToken->getLastLoc()));
Replace.Replacement = llvm::formatv("{}::create({}{}", OpType, BuilderText,
HasArgs ? ", " : "");

return SmallVector<Edit, 1>({Replace});
};
Expand All @@ -114,20 +123,18 @@ RewriteRuleWith<std::string> useNewMlirOpBuilderCheckRule() {
Stencil message = cat("use 'OpType::create(builder, ...)' instead of "
"'builder.create<OpType>(...)'");
// Match a create call on an OpBuilder.
ast_matchers::internal::Matcher<Stmt> base =
cxxMemberCallExpr(
on(expr(hasType(
cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"))))
.bind("builder")),
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()))),
callee(cxxMethodDecl(hasName("create"))))
.bind("call");
auto BuilderType = cxxRecordDecl(isSameOrDerivedFrom("::mlir::OpBuilder"));
ast_matchers::internal::Matcher<Stmt> Base = cxxMemberCallExpr(
on(expr(anyOf(hasType(BuilderType), hasType(pointsTo(BuilderType))))
.bind("builder")),
callee(expr().bind("call")),
callee(cxxMethodDecl(hasTemplateArgument(0, templateArgument()),
hasName("create"))));
return applyFirst(
// Attempt rewrite given an lvalue builder, else just warn.
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), base),
rewrite(node("call"), node("builder"), callArgs("call")),
message),
makeRule(base, noopEdit(node("call")), message)});
{makeRule(cxxMemberCallExpr(unless(on(cxxTemporaryObjectExpr())), Base),
rewrite(node("call"), node("builder")), message),
makeRule(Base, noopEdit(node("call")), message)});
}
} // namespace

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

namespace mlir {
class Location {};
class Value {};
class OpBuilder {
public:
template <typename OpTy, typename... Args>
Expand All @@ -28,6 +29,13 @@ struct NamedOp {
static NamedOp create(OpBuilder &builder, Location location, const char* name) {
return NamedOp(name);
}
Value getResult() { return Value(); }
};
struct OperandOp {
OperandOp(Value val) {}
static OperandOp create(OpBuilder &builder, Location location, Value val) {
return OperandOp(val);
}
};
} // namespace mlir

Expand All @@ -40,13 +48,24 @@ void g(mlir::OpBuilder &b) {
b.create<T>(b.getUnknownLoc(), "gaz");
}

class CustomBuilder : public mlir::ImplicitLocOpBuilder {
public:
mlir::NamedOp f(const char *name) {
// CHECK-MESSAGES: :[[@LINE+2]]:12: warning: use 'OpType::create(builder, ...)'
// CHECK-FIXES: mlir::NamedOp::create(*this, name);
return create<mlir::NamedOp>(name);
}
};

void f() {
mlir::OpBuilder builder;
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: mlir:: ModuleOp::create(builder, builder.getUnknownLoc())
builder.create<mlir:: ModuleOp>(builder.getUnknownLoc());

using mlir::NamedOp;
using mlir::OperandOp;

// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: NamedOp::create(builder, builder.getUnknownLoc(), "baz")
builder.create<NamedOp>(builder.getUnknownLoc(), "baz");
Expand All @@ -56,7 +75,7 @@ void f() {
// CHECK-FIXES: builder.getUnknownLoc(),
// CHECK-FIXES: "caz")
builder.
create<NamedOp>(
create<NamedOp> (
builder.getUnknownLoc(),
"caz");

Expand All @@ -67,10 +86,25 @@ void f() {

mlir::ImplicitLocOpBuilder ib;
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: mlir::ModuleOp::create(ib)
// CHECK-FIXES: mlir::ModuleOp::create(ib )
ib.create<mlir::ModuleOp>( );

// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());
mlir::OpBuilder().create<mlir::ModuleOp>(builder.getUnknownLoc());

auto *p = &builder;
// CHECK-MESSAGES: :[[@LINE+2]]:3: warning: use 'OpType::create(builder, ...)'
// CHECK-FIXES: NamedOp::create(*p, builder.getUnknownLoc(), "eaz")
p->create<NamedOp>(builder.getUnknownLoc(), "eaz");

CustomBuilder cb;
cb.f("faz");

// CHECK-MESSAGES: :[[@LINE+4]]:3: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: OperandOp::create(builder, builder.getUnknownLoc(),
// CHECK-MESSAGES: :[[@LINE+3]]:5: warning: use 'OpType::create(builder, ...)' instead of 'builder.create<OpType>(...)' [llvm-use-new-mlir-op-builder]
// CHECK-FIXES: NamedOp::create(builder,
builder.create<OperandOp>(builder.getUnknownLoc(),
builder.create<NamedOp>(builder.getUnknownLoc(), "gaz").getResult());
}
Loading