Skip to content

Commit 02d9f4d

Browse files
devajithvsjpienaar
authored andcommitted
[mlir][mlir-query] Introduce mlir-query tool with autocomplete support
This commit adds the initial version of the mlir-query tool, which leverages the pre-existing matchers defined in mlir/include/mlir/IR/Matchers.h The tool provides the following set of basic queries: hasOpAttrName(string) hasOpName(string) isConstantOp() isNegInfFloat() isNegZeroFloat() isNonZero() isOne() isOneFloat() isPosInfFloat() isPosZeroFloat() isZero() isZeroFloat() Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D155127
1 parent 1673a1b commit 02d9f4d

31 files changed

+2650
-0
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
//===--- ErrorBuilder.h - Helper for building error messages ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// ErrorBuilder to manage error messages.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
14+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
15+
16+
#include "llvm/ADT/StringRef.h"
17+
#include "llvm/ADT/Twine.h"
18+
#include <initializer_list>
19+
20+
namespace mlir::query::matcher::internal {
21+
class Diagnostics;
22+
23+
// Represents the line and column numbers in a source query.
24+
struct SourceLocation {
25+
unsigned line{};
26+
unsigned column{};
27+
};
28+
29+
// Represents a range in a source query, defined by its start and end locations.
30+
struct SourceRange {
31+
SourceLocation start{};
32+
SourceLocation end{};
33+
};
34+
35+
// All errors from the system.
36+
enum class ErrorType {
37+
None,
38+
39+
// Parser Errors
40+
ParserFailedToBuildMatcher,
41+
ParserInvalidToken,
42+
ParserNoCloseParen,
43+
ParserNoCode,
44+
ParserNoComma,
45+
ParserNoOpenParen,
46+
ParserNotAMatcher,
47+
ParserOverloadedType,
48+
ParserStringError,
49+
ParserTrailingCode,
50+
51+
// Registry Errors
52+
RegistryMatcherNotFound,
53+
RegistryValueNotFound,
54+
RegistryWrongArgCount,
55+
RegistryWrongArgType
56+
};
57+
58+
void addError(Diagnostics *error, SourceRange range, ErrorType errorType,
59+
std::initializer_list<llvm::Twine> errorTexts);
60+
61+
} // namespace mlir::query::matcher::internal
62+
63+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_ERRORBUILDER_H
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
//===--- Marshallers.h - Generic matcher function marshallers ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains function templates and classes to wrap matcher construct
10+
// functions. It provides a collection of template function and classes that
11+
// present a generic marshalling layer on top of matcher construct functions.
12+
// The registry uses these to export all marshaller constructors with a uniform
13+
// interface. This mechanism takes inspiration from clang-query.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
17+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
18+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
19+
20+
#include "ErrorBuilder.h"
21+
#include "VariantValue.h"
22+
#include "llvm/ADT/ArrayRef.h"
23+
#include "llvm/ADT/StringRef.h"
24+
25+
namespace mlir::query::matcher::internal {
26+
27+
// Helper template class for jumping from argument type to the correct is/get
28+
// functions in VariantValue. This is used for verifying and extracting the
29+
// matcher arguments.
30+
template <class T>
31+
struct ArgTypeTraits;
32+
template <class T>
33+
struct ArgTypeTraits<const T &> : public ArgTypeTraits<T> {};
34+
35+
template <>
36+
struct ArgTypeTraits<llvm::StringRef> {
37+
38+
static bool hasCorrectType(const VariantValue &value) {
39+
return value.isString();
40+
}
41+
42+
static const llvm::StringRef &get(const VariantValue &value) {
43+
return value.getString();
44+
}
45+
46+
static ArgKind getKind() { return ArgKind::String; }
47+
48+
static std::optional<std::string> getBestGuess(const VariantValue &) {
49+
return std::nullopt;
50+
}
51+
};
52+
53+
template <>
54+
struct ArgTypeTraits<DynMatcher> {
55+
56+
static bool hasCorrectType(const VariantValue &value) {
57+
return value.isMatcher();
58+
}
59+
60+
static DynMatcher get(const VariantValue &value) {
61+
return *value.getMatcher().getDynMatcher();
62+
}
63+
64+
static ArgKind getKind() { return ArgKind::Matcher; }
65+
66+
static std::optional<std::string> getBestGuess(const VariantValue &) {
67+
return std::nullopt;
68+
}
69+
};
70+
71+
// Interface for generic matcher descriptor.
72+
// Offers a create() method that constructs the matcher from the provided
73+
// arguments.
74+
class MatcherDescriptor {
75+
public:
76+
virtual ~MatcherDescriptor() = default;
77+
virtual VariantMatcher create(SourceRange nameRange,
78+
const llvm::ArrayRef<ParserValue> args,
79+
Diagnostics *error) const = 0;
80+
81+
// Returns the number of arguments accepted by the matcher.
82+
virtual unsigned getNumArgs() const = 0;
83+
84+
// Append the set of argument types accepted for argument 'argNo' to
85+
// 'argKinds'.
86+
virtual void getArgKinds(unsigned argNo,
87+
std::vector<ArgKind> &argKinds) const = 0;
88+
};
89+
90+
class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
91+
public:
92+
using MarshallerType = VariantMatcher (*)(void (*matcherFunc)(),
93+
llvm::StringRef matcherName,
94+
SourceRange nameRange,
95+
llvm::ArrayRef<ParserValue> args,
96+
Diagnostics *error);
97+
98+
// Marshaller Function to unpack the arguments and call Func. Func is the
99+
// Matcher construct function. This is the function that the matcher
100+
// expressions would use to create the matcher.
101+
FixedArgCountMatcherDescriptor(MarshallerType marshaller,
102+
void (*matcherFunc)(),
103+
llvm::StringRef matcherName,
104+
llvm::ArrayRef<ArgKind> argKinds)
105+
: marshaller(marshaller), matcherFunc(matcherFunc),
106+
matcherName(matcherName), argKinds(argKinds.begin(), argKinds.end()) {}
107+
108+
VariantMatcher create(SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
109+
Diagnostics *error) const override {
110+
return marshaller(matcherFunc, matcherName, nameRange, args, error);
111+
}
112+
113+
unsigned getNumArgs() const override { return argKinds.size(); }
114+
115+
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
116+
kinds.push_back(argKinds[argNo]);
117+
}
118+
119+
private:
120+
const MarshallerType marshaller;
121+
void (*const matcherFunc)();
122+
const llvm::StringRef matcherName;
123+
const std::vector<ArgKind> argKinds;
124+
};
125+
126+
// Helper function to check if argument count matches expected count
127+
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
128+
llvm::ArrayRef<ParserValue> args,
129+
Diagnostics *error) {
130+
if (args.size() != expectedArgCount) {
131+
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
132+
{llvm::Twine(expectedArgCount), llvm::Twine(args.size())});
133+
return false;
134+
}
135+
return true;
136+
}
137+
138+
// Helper function for checking argument type
139+
template <typename ArgType, size_t Index>
140+
inline bool checkArgTypeAtIndex(llvm::StringRef matcherName,
141+
llvm::ArrayRef<ParserValue> args,
142+
Diagnostics *error) {
143+
if (!ArgTypeTraits<ArgType>::hasCorrectType(args[Index].value)) {
144+
addError(error, args[Index].range, ErrorType::RegistryWrongArgType,
145+
{llvm::Twine(matcherName), llvm::Twine(Index + 1)});
146+
return false;
147+
}
148+
return true;
149+
}
150+
151+
// Marshaller function for fixed number of arguments
152+
template <typename ReturnType, typename... ArgTypes, size_t... Is>
153+
static VariantMatcher
154+
matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
155+
SourceRange nameRange,
156+
llvm::ArrayRef<ParserValue> args, Diagnostics *error,
157+
std::index_sequence<Is...>) {
158+
using FuncType = ReturnType (*)(ArgTypes...);
159+
160+
// Check if the argument count matches the expected count
161+
if (!checkArgCount(nameRange, sizeof...(ArgTypes), args, error))
162+
return VariantMatcher();
163+
164+
// Check if each argument at the corresponding index has the correct type
165+
if ((... && checkArgTypeAtIndex<ArgTypes, Is>(matcherName, args, error))) {
166+
ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
167+
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
168+
return VariantMatcher::SingleMatcher(
169+
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
170+
}
171+
172+
return VariantMatcher();
173+
}
174+
175+
template <typename ReturnType, typename... ArgTypes>
176+
static VariantMatcher
177+
matcherMarshallFixed(void (*matcherFunc)(), llvm::StringRef matcherName,
178+
SourceRange nameRange, llvm::ArrayRef<ParserValue> args,
179+
Diagnostics *error) {
180+
return matcherMarshallFixedImpl<ReturnType, ArgTypes...>(
181+
matcherFunc, matcherName, nameRange, args, error,
182+
std::index_sequence_for<ArgTypes...>{});
183+
}
184+
185+
// Fixed number of arguments overload
186+
template <typename ReturnType, typename... ArgTypes>
187+
std::unique_ptr<MatcherDescriptor>
188+
makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
189+
llvm::StringRef matcherName) {
190+
// Create a vector of argument kinds
191+
std::vector<ArgKind> argKinds = {ArgTypeTraits<ArgTypes>::getKind()...};
192+
return std::make_unique<FixedArgCountMatcherDescriptor>(
193+
matcherMarshallFixed<ReturnType, ArgTypes...>,
194+
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
195+
}
196+
197+
} // namespace mlir::query::matcher::internal
198+
199+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//===- MatchFinder.h - ------------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains the MatchFinder class, which is used to find operations
10+
// that match a given matcher.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
15+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
16+
17+
#include "MatchersInternal.h"
18+
19+
namespace mlir::query::matcher {
20+
21+
// MatchFinder is used to find all operations that match a given matcher.
22+
class MatchFinder {
23+
public:
24+
// Returns all operations that match the given matcher.
25+
static std::vector<Operation *> getMatches(Operation *root,
26+
DynMatcher matcher) {
27+
std::vector<Operation *> matches;
28+
29+
// Simple match finding with walk.
30+
root->walk([&](Operation *subOp) {
31+
if (matcher.match(subOp))
32+
matches.push_back(subOp);
33+
});
34+
35+
return matches;
36+
}
37+
};
38+
39+
} // namespace mlir::query::matcher
40+
41+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Implements the base layer of the matcher framework.
10+
//
11+
// Matchers are methods that return a Matcher which provides a method
12+
// match(Operation *op)
13+
//
14+
// The matcher functions are defined in include/mlir/IR/Matchers.h.
15+
// This file contains the wrapper classes needed to construct matchers for
16+
// mlir-query.
17+
//
18+
//===----------------------------------------------------------------------===//
19+
20+
#ifndef MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
21+
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
22+
23+
#include "mlir/IR/Matchers.h"
24+
#include "llvm/ADT/IntrusiveRefCntPtr.h"
25+
26+
namespace mlir::query::matcher {
27+
28+
// Generic interface for matchers on an MLIR operation.
29+
class MatcherInterface
30+
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
31+
public:
32+
virtual ~MatcherInterface() = default;
33+
34+
virtual bool match(Operation *op) = 0;
35+
};
36+
37+
// MatcherFnImpl takes a matcher function object and implements
38+
// MatcherInterface.
39+
template <typename MatcherFn>
40+
class MatcherFnImpl : public MatcherInterface {
41+
public:
42+
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43+
bool match(Operation *op) override { return matcherFn.match(op); }
44+
45+
private:
46+
MatcherFn matcherFn;
47+
};
48+
49+
// Matcher wraps a MatcherInterface implementation and provides a match()
50+
// method that redirects calls to the underlying implementation.
51+
class DynMatcher {
52+
public:
53+
// Takes ownership of the provided implementation pointer.
54+
DynMatcher(MatcherInterface *implementation)
55+
: implementation(implementation) {}
56+
57+
template <typename MatcherFn>
58+
static std::unique_ptr<DynMatcher>
59+
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
60+
auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
61+
return std::make_unique<DynMatcher>(impl.release());
62+
}
63+
64+
bool match(Operation *op) const { return implementation->match(op); }
65+
66+
private:
67+
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
68+
};
69+
70+
} // namespace mlir::query::matcher
71+
72+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

0 commit comments

Comments
 (0)