Skip to content

Commit

Permalink
Add caches for matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
Mytherin committed Nov 21, 2024
1 parent c870b72 commit 1251945
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 41 deletions.
4 changes: 2 additions & 2 deletions extension/autocomplete/include/matcher.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class Matcher {
virtual ~Matcher() = default;

//! Match
virtual MatchResultType Match(MatchState &state) = 0;
virtual SuggestionType AddSuggestion(MatchState &state) = 0;
virtual MatchResultType Match(MatchState &state) const = 0;
virtual SuggestionType AddSuggestion(MatchState &state) const = 0;

static Matcher &RootMatcher(MatcherAllocator &allocator);
};
Expand Down
118 changes: 79 additions & 39 deletions extension/autocomplete/matcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class KeywordMatcher : public Matcher {
explicit KeywordMatcher(string keyword_p, int32_t score_bonus = 0, char extra_char = '\0') : keyword(std::move(keyword_p)), score_bonus(score_bonus), extra_char(extra_char) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
auto &token = state.tokens[state.token_index];
if (StringUtil::CIEquals(keyword, token.text)) {
// move to the next token
Expand All @@ -18,7 +18,7 @@ class KeywordMatcher : public Matcher {
}
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
AutoCompleteCandidate candidate(keyword, score_bonus, CandidateMatchCase::MATCH_CASE);
candidate.extra_char = extra_char;
state.suggestions.emplace_back(std::move(candidate));
Expand All @@ -36,7 +36,7 @@ class ListMatcher : public Matcher {
explicit ListMatcher(vector<reference<Matcher>> matchers_p) : matchers(std::move(matchers_p)) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
MatchState list_state(state);
for (idx_t child_idx = 0; child_idx < matchers.size(); child_idx++) {
auto &child_matcher = matchers[child_idx].get();
Expand All @@ -63,7 +63,7 @@ class ListMatcher : public Matcher {
return MatchResultType::SUCCESS;
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
for (auto &matcher : matchers) {
auto suggestion_result = matcher.get().AddSuggestion(state);
if (suggestion_result == SuggestionType::MANDATORY) {
Expand All @@ -81,10 +81,10 @@ class ListMatcher : public Matcher {

class OptionalMatcher : public Matcher {
public:
OptionalMatcher(Matcher &matcher_p) : matcher(matcher_p) {
explicit OptionalMatcher(Matcher &matcher_p) : matcher(matcher_p) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
MatchState child_state(state);
auto child_match = matcher.Match(child_state);
if (child_match != MatchResultType::SUCCESS) {
Expand All @@ -96,7 +96,7 @@ class OptionalMatcher : public Matcher {
return MatchResultType::SUCCESS;
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
matcher.AddSuggestion(state);
return SuggestionType::OPTIONAL;
}
Expand All @@ -110,7 +110,7 @@ class ChoiceMatcher : public Matcher {
explicit ChoiceMatcher(vector<reference<Matcher>> matchers_p) : matchers(std::move(matchers_p)) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
for (auto &child_matcher : matchers) {
MatchState choice_state(state);
auto child_result = child_matcher.get().Match(choice_state);
Expand All @@ -123,7 +123,7 @@ class ChoiceMatcher : public Matcher {
return MatchResultType::FAIL;
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
for (auto &child_matcher : matchers) {
child_matcher.get().AddSuggestion(state);
}
Expand All @@ -140,7 +140,7 @@ class RepeatMatcher : public Matcher {
: element(element_p), separator(separator_p) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
MatchState repeat_state(state);
while (true) {
// we exhausted the tokens - suggest the element
Expand Down Expand Up @@ -180,7 +180,7 @@ class RepeatMatcher : public Matcher {
}
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
element.AddSuggestion(state);
return SuggestionType::MANDATORY;
}
Expand All @@ -192,15 +192,15 @@ class RepeatMatcher : public Matcher {

class VariableMatcher : public Matcher {
public:
VariableMatcher(SuggestionState suggestion_type) : suggestion_type(suggestion_type) {
explicit VariableMatcher(SuggestionState suggestion_type) : suggestion_type(suggestion_type) {
}

MatchResultType Match(MatchState &state) override {
MatchResultType Match(MatchState &state) const override {
state.token_index++;
return MatchResultType::SUCCESS;
}

SuggestionType AddSuggestion(MatchState &state) override {
SuggestionType AddSuggestion(MatchState &state) const override {
state.suggestions.emplace_back(suggestion_type);
return SuggestionType::MANDATORY;
}
Expand All @@ -214,6 +214,11 @@ Matcher &MatcherAllocator::Allocate(unique_ptr<Matcher> matcher) {
return result;
}

enum class CachedMatcherType {
IF_NOT_EXISTS,
QUALIFIED_TABLE_NAME
};

//! Class for building matchers
class MatcherFactory {
public:
Expand All @@ -223,20 +228,21 @@ class MatcherFactory {

private:
// Base primitives
Matcher &Keyword(const string &keyword, int32_t score_bonus= 0, char extra_char = ' ');
Matcher &List(vector<reference<Matcher>> matchers);
Matcher &Choice(vector<reference<Matcher>> matchers);
Matcher &Optional(Matcher &matcher);
Matcher &Repeat(Matcher &matcher, optional_ptr<Matcher> separator);
Matcher &Variable();
Matcher &CatalogName();
Matcher &SchemaName();
Matcher &TypeName();
Matcher &TableName();
Matcher &Keyword(const string &keyword, int32_t score_bonus= 0, char extra_char = ' ') const;
Matcher &List(vector<reference<Matcher>> matchers) const;
Matcher &Choice(vector<reference<Matcher>> matchers) const;
Matcher &Optional(Matcher &matcher) const;
Matcher &Repeat(Matcher &matcher, optional_ptr<Matcher> separator) const;
Matcher &Variable() const;
Matcher &CatalogName() const;
Matcher &SchemaName() const;
Matcher &TypeName() const;
Matcher &TableName() const;

private:
// Matchers
Matcher &TemporaryOrReplace();
Matcher &OrReplace();
Matcher &Temporary();
Matcher &IfNotExists();
Matcher &CatalogQualification();
Matcher &SchemaQualification();
Expand All @@ -261,63 +267,91 @@ class MatcherFactory {
Matcher &CreateMatchers();
Matcher &CreateStatementMatcher();

private:
optional_ptr<Matcher> CheckCache(CachedMatcherType matcher_type) {
auto entry = matcher_cache.find(matcher_type);
if (entry != matcher_cache.end()) {
// entry was already created - return the entry
return &entry->second.get();
}
return nullptr;
}

Matcher &Cache(CachedMatcherType matcher_type, Matcher &matcher) {
matcher_cache.insert(make_pair(matcher_type, reference<Matcher>(matcher)));
return matcher;
}

private:
MatcherAllocator &allocator;
unordered_map<CachedMatcherType, reference<Matcher>> matcher_cache;
};

Matcher &MatcherFactory::Keyword(const string &keyword, int32_t score_bonus, char extra_char) {
Matcher &MatcherFactory::Keyword(const string &keyword, int32_t score_bonus, char extra_char) const {
return allocator.Allocate(make_uniq<KeywordMatcher>(keyword, score_bonus, extra_char));
}

Matcher &MatcherFactory::List(vector<reference<Matcher>> matchers) {
Matcher &MatcherFactory::List(vector<reference<Matcher>> matchers) const {
return allocator.Allocate(make_uniq<ListMatcher>(std::move(matchers)));
}

Matcher &MatcherFactory::Choice(vector<reference<Matcher>> matchers) {
Matcher &MatcherFactory::Choice(vector<reference<Matcher>> matchers) const {
return allocator.Allocate(make_uniq<ChoiceMatcher>(std::move(matchers)));
}

Matcher &MatcherFactory::Optional(Matcher &matcher) {
Matcher &MatcherFactory::Optional(Matcher &matcher) const {
return allocator.Allocate(make_uniq<OptionalMatcher>(matcher));
}

Matcher &MatcherFactory::Repeat(Matcher &matcher, optional_ptr<Matcher> separator) {
Matcher &MatcherFactory::Repeat(Matcher &matcher, optional_ptr<Matcher> separator) const {
return allocator.Allocate(make_uniq<RepeatMatcher>(matcher, separator));
}

Matcher &MatcherFactory::Variable() {
Matcher &MatcherFactory::Variable() const {
return allocator.Allocate(make_uniq<VariableMatcher>(SuggestionState::SUGGEST_VARIABLE));
}

Matcher &MatcherFactory::CatalogName() {
Matcher &MatcherFactory::CatalogName() const {
return allocator.Allocate(make_uniq<VariableMatcher>(SuggestionState::SUGGEST_CATALOG_NAME));
}

Matcher &MatcherFactory::SchemaName() {
Matcher &MatcherFactory::SchemaName() const {
return allocator.Allocate(make_uniq<VariableMatcher>(SuggestionState::SUGGEST_SCHEMA_NAME));
}

Matcher &MatcherFactory::TypeName() {
Matcher &MatcherFactory::TypeName() const {
return allocator.Allocate(make_uniq<VariableMatcher>(SuggestionState::SUGGEST_TYPE_NAME));
}

Matcher &MatcherFactory::TableName() {
Matcher &MatcherFactory::TableName() const {
return allocator.Allocate(make_uniq<VariableMatcher>(SuggestionState::SUGGEST_TABLE_NAME));
}

Matcher &MatcherFactory::TemporaryOrReplace() {
Matcher &MatcherFactory::OrReplace() {
vector<reference<Matcher>> m;
m.push_back(Keyword("OR"));
m.push_back(Keyword("REPLACE"));
return Optional(List(std::move(m)));
}

Matcher &MatcherFactory::Temporary() {
vector<reference<Matcher>> m;
m.push_back(Keyword("TEMP"));
m.push_back(Keyword("TEMPORARY"));
return Optional(Choice(std::move(m)));
}

Matcher &MatcherFactory::IfNotExists() {
auto cache_type = CachedMatcherType::IF_NOT_EXISTS;
auto cache_entry = CheckCache(cache_type);
if (cache_entry) {
return *cache_entry;
}
vector<reference<Matcher>> m;
m.push_back(Keyword("IF"));
m.push_back(Keyword("NOT"));
m.push_back(Keyword("EXISTS"));
return Optional(List(std::move(m)));
return Cache(cache_type, Optional(List(std::move(m))));
}

Matcher &MatcherFactory::CatalogQualification() {
Expand All @@ -342,11 +376,16 @@ Matcher &MatcherFactory::QualifiedSchemaName() {
}

Matcher &MatcherFactory::QualifiedTableName() {
auto cache_type = CachedMatcherType::QUALIFIED_TABLE_NAME;
auto cache_entry = CheckCache(cache_type);
if (cache_entry) {
return *cache_entry;
}
vector<reference<Matcher>> m;
m.push_back(Optional(CatalogQualification()));
m.push_back(Optional(SchemaQualification()));
m.push_back(Variable());
return List(std::move(m));
return Cache(cache_type, List(std::move(m)));
}

Matcher &MatcherFactory::NotNullConstraint() {
Expand Down Expand Up @@ -477,7 +516,8 @@ Matcher &MatcherFactory::CreateMatchers() {
Matcher &MatcherFactory::CreateStatementMatcher() {
vector<reference<Matcher>> m;
m.push_back(Keyword("CREATE"));
m.push_back(TemporaryOrReplace());
m.push_back(OrReplace());
m.push_back(Temporary());
m.push_back(CreateMatchers());
return List(std::move(m));
}
Expand Down
5 changes: 5 additions & 0 deletions test/sql/function/autocomplete/create_table.test
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ FROM sql_auto_complete('CREATE T') LIMIT 1;
----
TABLE 7

query II
FROM sql_auto_complete('CREATE OR RE') LIMIT 1;
----
REPLACE 10

query II
FROM sql_auto_complete('create ta') LIMIT 1;
----
Expand Down

0 comments on commit 1251945

Please sign in to comment.