From 5c4f5ee83c8e6456e9bbc10010529d8b32532fc2 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 28 Mar 2026 17:46:50 -0400 Subject: [PATCH] break sqlcompleter.py find_matches() into units and add test coverage. This also changes find_matches() into an instance method, but we could consider changing find_matches() and many others into static methods. Motivation: smaller units make the code more testable and more amenable to agentic coding. --- changelog.md | 5 + mycli/sqlcompleter.py | 221 +++++++----- .../pytests/test_sqlcompleter_find_matches.py | 338 ++++++++++++++++++ 3 files changed, 478 insertions(+), 86 deletions(-) create mode 100644 test/pytests/test_sqlcompleter_find_matches.py diff --git a/changelog.md b/changelog.md index 4c9602e5..08e9fcbf 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,11 @@ Features * Continue to expand TIPS. +Internal +--------- +* Refactor `find_matches()` into smaller logical units. + + 1.67.1 (2026/03/28) ============== diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 44e1bcb2..d5429f42 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -19,6 +19,7 @@ from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS _logger = logging.getLogger(__name__) +_CASE_CHANGE_PAT = re.compile('(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])') class Fuzziness(IntEnum): @@ -1173,8 +1174,135 @@ def reset_completions(self) -> None: } self.all_completions = set(self.keywords + self.functions) - @staticmethod + def maybe_quote_identifier(self, item: str) -> str: + if item.startswith('`'): + return item + if item == '*': + return item + return '`' + item + '`' + + def quote_collection_if_needed( + self, + text: str, + collection: Collection[Any], + text_before_cursor: str, + ) -> Collection[Any]: + # checking text.startswith() first is an optimization; is_inside_quotes() covers more cases + if text.startswith('`') or is_inside_quotes(text_before_cursor, len(text_before_cursor)) == 'backtick': + return [self.maybe_quote_identifier(x) if isinstance(x, str) else x for x in collection] + return collection + + def word_parts_match( + self, + text_parts: list[str], + item_parts: list[str], + ) -> bool: + occurrences = 0 + for text_part in text_parts: + for item_part in item_parts: + if item_part.startswith(text_part): + occurrences += 1 + break + return occurrences >= len(text_parts) + + def find_fuzzy_match( + self, + item: str, + pattern: re.Pattern[str], + under_words_text: list[str], + case_words_text: list[str], + ) -> int | None: + if pattern.search(item.lower()): + return Fuzziness.REGEX + + under_words_item = [x for x in item.lower().split('_') if x] + if self.word_parts_match(under_words_text, under_words_item): + return Fuzziness.UNDER_WORDS + + case_words_item = re.split(_CASE_CHANGE_PAT, item) + if self.word_parts_match(case_words_text, case_words_item): + return Fuzziness.CAMEL_CASE + + return None + + def find_fuzzy_matches( + self, + last: str, + text: str, + collection: Collection[Any], + ) -> list[tuple[str, int]]: + completions: list[tuple[str, int]] = [] + regex = '.{0,3}?'.join(map(re.escape, text)) + pattern = re.compile(f'({regex})') + under_words_text = [x for x in text.split('_') if x] + case_words_text = re.split(_CASE_CHANGE_PAT, last) + + for item in collection: + fuzziness = self.find_fuzzy_match(item, pattern, under_words_text, case_words_text) + if fuzziness is not None: + completions.append((item, fuzziness)) + + if len(text) >= 4: + rapidfuzz_matches = rapidfuzz.process.extract( + text, + collection, + scorer=rapidfuzz.fuzz.WRatio, + # todo: maybe make our own processor which only does case-folding + # because underscores are valuable info + processor=rapidfuzz.utils.default_process, + limit=20, + score_cutoff=75, + ) + for item, _score, _type in rapidfuzz_matches: + if len(item) < len(text) / 1.5: + continue + if item in completions: + continue + completions.append((item, Fuzziness.RAPIDFUZZ)) + + return completions + + def find_perfect_matches( + self, + text: str, + collection: Collection[Any], + start_only: bool, + ) -> list[tuple[str, int]]: + completions: list[tuple[str, int]] = [] + match_end_limit = len(text) if start_only else None + for item in collection: + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + completions.append((item, Fuzziness.PERFECT)) + return completions + + def resolve_casing( + self, + casing: str | None, + last: str, + ) -> str | None: + if casing != 'auto': + return casing + return 'lower' if last and (last[0].islower() or last[-1].islower()) else 'upper' + + def apply_casing( + self, + completions: list[tuple[str, int]], + casing: str | None, + ) -> Generator[tuple[str, int], None, None]: + if casing is None: + return (completion for completion in completions) + + def apply_case(tup: tuple[str, int]) -> tuple[str, int]: + kw, fuzziness = tup + if casing == 'upper': + return (kw.upper(), fuzziness) + return (kw.lower(), fuzziness) + + return (apply_case(completion) for completion in completions) + def find_matches( + self, orig_text: str, collection: Collection, start_only: bool = False, @@ -1195,96 +1323,17 @@ def find_matches( yields prompt_toolkit Completion instances for any matches found in the collection of available completions. """ - last = last_word(orig_text, include="most_punctuations") + last = last_word(orig_text, include='most_punctuations') text = last.lower() - # unicode support not possible without adding the regex dependency - case_change_pat = re.compile("(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])") - - completions: list[tuple[str, int]] = [] - - def maybe_quote_identifier(item: str) -> str: - if item.startswith('`'): - return item - if item == '*': - return item - return '`' + item + '`' - - # checking text.startswith() first is an optimization; is_inside_quotes() covers more cases - if text.startswith('`') or is_inside_quotes(text_before_cursor, len(text_before_cursor)) == 'backtick': - quoted_collection: Collection[Any] = [maybe_quote_identifier(x) if isinstance(x, str) else x for x in collection] - else: - quoted_collection = collection + quoted_collection = self.quote_collection_if_needed(text, collection, text_before_cursor) if fuzzy: - regex = ".{0,3}?".join(map(re.escape, text)) - pat = re.compile(f'({regex})') - under_words_text = [x for x in text.split('_') if x] - case_words_text = re.split(case_change_pat, last) - - for item in quoted_collection: - r = pat.search(item.lower()) - if r: - completions.append((item, Fuzziness.REGEX)) - continue - - under_words_item = [x for x in item.lower().split('_') if x] - occurrences = 0 - for elt_word in under_words_text: - for elt_item in under_words_item: - if elt_item.startswith(elt_word): - occurrences += 1 - break - if occurrences >= len(under_words_text): - completions.append((item, Fuzziness.UNDER_WORDS)) - continue - - case_words_item = re.split(case_change_pat, item) - occurrences = 0 - for elt_word in case_words_text: - for elt_item in case_words_item: - if elt_item.startswith(elt_word): - occurrences += 1 - break - if occurrences >= len(case_words_text): - completions.append((item, Fuzziness.CAMEL_CASE)) - continue - - if len(text) >= 4: - rapidfuzz_matches = rapidfuzz.process.extract( - text, - quoted_collection, - scorer=rapidfuzz.fuzz.WRatio, - # todo: maybe make our own processor which only does case-folding - # because underscores are valuable info - processor=rapidfuzz.utils.default_process, - limit=20, - score_cutoff=75, - ) - for elt in rapidfuzz_matches: - item, _score, _type = elt - if len(item) < len(text) / 1.5: - continue - if item in completions: - continue - completions.append((item, Fuzziness.RAPIDFUZZ)) - + completions = self.find_fuzzy_matches(last, text, quoted_collection) else: - match_end_limit = len(text) if start_only else None - for item in quoted_collection: - match_point = item.lower().find(text, 0, match_end_limit) - if match_point >= 0: - completions.append((item, Fuzziness.PERFECT)) - - if casing == "auto": - casing = "lower" if last and (last[0].islower() or last[-1].islower()) else "upper" - - def apply_case(tup: tuple[str, int]) -> tuple[str, int]: - kw, fuzziness = tup - if casing == "upper": - return (kw.upper(), fuzziness) - return (kw.lower(), fuzziness) + completions = self.find_perfect_matches(text, quoted_collection, start_only) - return (x if casing is None else apply_case(x) for x in completions) + casing = self.resolve_casing(casing, last) + return self.apply_casing(completions, casing) def get_completions( self, diff --git a/test/pytests/test_sqlcompleter_find_matches.py b/test/pytests/test_sqlcompleter_find_matches.py new file mode 100644 index 00000000..b7efb528 --- /dev/null +++ b/test/pytests/test_sqlcompleter_find_matches.py @@ -0,0 +1,338 @@ +# type: ignore + +import re + +import pytest + +import mycli.sqlcompleter +from mycli.sqlcompleter import Fuzziness, SQLCompleter + + +def collect_matches( + orig_text: str, + collection: list[str], + *, + start_only: bool = False, + fuzzy: bool = True, + casing: str | None = None, + text_before_cursor: str = '', +) -> list[tuple[str, int]]: + completer = SQLCompleter() + return list( + completer.find_matches( + orig_text, + collection, + start_only=start_only, + fuzzy=fuzzy, + casing=casing, + text_before_cursor=text_before_cursor, + ) + ) + + +@pytest.mark.parametrize( + ('item', 'expected'), + [ + ('users', '`users`'), + ('`already`', '`already`'), + ('*', '*'), + ], +) +def test_maybe_quote_identifier(item: str, expected: str) -> None: + completer = SQLCompleter() + assert completer.maybe_quote_identifier(item) == expected + + +def test_quote_collection_if_needed_quotes_when_text_starts_with_backtick() -> None: + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('`us', ['users', '*'], '') + + assert quoted == ['`users`', '*'] + + +def test_quote_collection_if_needed_quotes_when_cursor_is_inside_backticks() -> None: + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('us', ['users', '`uuid`'], 'select `us') + + assert quoted == ['`users`', '`uuid`'] + + +def test_quote_collection_if_needed_leaves_collection_unchanged_when_not_quoted() -> None: + collection = ['users', '*'] + completer = SQLCompleter() + quoted = completer.quote_collection_if_needed('us', collection, 'select us') + + assert quoted is collection + + +@pytest.mark.parametrize( + ('text_parts', 'item_parts', 'expected'), + [ + (['us', 'de', 'fu'], ['user', 'defined', 'function'], True), + (['us', 'fu'], ['user', 'defined', 'function'], True), + (['us', 'zz'], ['user', 'defined', 'function'], False), + ([], ['user', 'defined', 'function'], True), + (['us'], [], False), + ], +) +def test_word_parts_match( + text_parts: list[str], + item_parts: list[str], + expected: bool, +) -> None: + completer = SQLCompleter() + assert completer.word_parts_match(text_parts, item_parts) is expected + + +@pytest.mark.parametrize( + ('item', 'pattern', 'under_words_text', 'case_words_text', 'expected'), + [ + ('foo_select_bar', re.compile('(s.{0,3}?e.{0,3}?l)'), ['sel'], ['sel'], Fuzziness.REGEX), + ('user_defined_function', re.compile('(z.{0,3}?z)'), ['us', 'de', 'fu'], ['us_de_fu'], Fuzziness.UNDER_WORDS), + ('TimeZoneTransitionType', re.compile('(Ti.{0,3}?Zx)'), ['TiZoTrTy'], ['Ti', 'Zo', 'Tr', 'Ty'], Fuzziness.CAMEL_CASE), + ('orders', re.compile('(z.{0,3}?z)'), ['zz'], ['zz'], None), + ], +) +def test_find_fuzzy_match( + item: str, + pattern: re.Pattern[str], + under_words_text: list[str], + case_words_text: list[str], + expected: int | None, +) -> None: + completer = SQLCompleter() + assert completer.find_fuzzy_match(item, pattern, under_words_text, case_words_text) == expected + + +def test_find_fuzzy_matches_collects_item_level_matches(monkeypatch) -> None: + monkeypatch.setattr( + SQLCompleter, + 'find_fuzzy_match', + lambda self, item, pattern, under_words_text, case_words_text: { + 'orders': Fuzziness.REGEX, + 'order_items': Fuzziness.UNDER_WORDS, + 'other': None, + }[item], + ) + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', lambda *args, **kwargs: []) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('OrIt', 'orit', ['orders', 'order_items', 'other']) + + assert matches == [ + ('orders', Fuzziness.REGEX), + ('order_items', Fuzziness.UNDER_WORDS), + ] + + +def test_find_fuzzy_matches_skips_rapidfuzz_for_short_text(monkeypatch) -> None: + monkeypatch.setattr(SQLCompleter, 'find_fuzzy_match', lambda *args, **kwargs: None) + + def fail_extract(*args, **kwargs): + raise AssertionError('rapidfuzz should not be called') + + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', fail_extract) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('sel', 'sel', ['SELECT']) + + assert matches == [] + + +def test_find_fuzzy_matches_appends_rapidfuzz_results_and_keeps_current_duplicates(monkeypatch) -> None: + monkeypatch.setattr( + SQLCompleter, + 'find_fuzzy_match', + lambda self, item, pattern, under_words_text, case_words_text: Fuzziness.REGEX if item == 'alphabet' else None, + ) + monkeypatch.setattr( + mycli.sqlcompleter.rapidfuzz.process, + 'extract', + lambda *args, **kwargs: [('abc', 99, 0), ('alphabet', 95, 1), ('alphanumeric', 90, 2)], + ) + completer = SQLCompleter() + matches = completer.find_fuzzy_matches('alpahet', 'alpahet', ['abc', 'alphabet', 'alphanumeric']) + + assert matches == [ + ('alphabet', Fuzziness.REGEX), + ('alphabet', Fuzziness.RAPIDFUZZ), + ('alphanumeric', Fuzziness.RAPIDFUZZ), + ] + + +@pytest.mark.parametrize( + ('text', 'collection', 'start_only', 'expected'), + [ + ('ord', ['orders', 'user_orders'], True, [('orders', Fuzziness.PERFECT)]), + ('name', ['table_name', 'name_table'], False, [('table_name', Fuzziness.PERFECT), ('name_table', Fuzziness.PERFECT)]), + ('', ['orders', 'users'], True, [('orders', Fuzziness.PERFECT), ('users', Fuzziness.PERFECT)]), + ], +) +def test_find_perfect_matches( + text: str, + collection: list[str], + start_only: bool, + expected: list[tuple[str, int]], +) -> None: + completer = SQLCompleter() + assert completer.find_perfect_matches(text, collection, start_only) == expected + + +@pytest.mark.parametrize( + ('casing', 'last', 'expected'), + [ + (None, 'Sel', None), + ('upper', 'sel', 'upper'), + ('lower', 'SEL', 'lower'), + ('auto', 'sel', 'lower'), + ('auto', 'SEl', 'lower'), + ('auto', 'SEL', 'upper'), + ('auto', '', 'upper'), + ], +) +def test_resolve_casing(casing: str | None, last: str, expected: str | None) -> None: + completer = SQLCompleter() + assert completer.resolve_casing(casing, last) == expected + + +@pytest.mark.parametrize( + ('completions', 'casing', 'expected'), + [ + ([('Select', Fuzziness.REGEX)], None, [('Select', Fuzziness.REGEX)]), + ([('Select', Fuzziness.REGEX)], 'upper', [('SELECT', Fuzziness.REGEX)]), + ([('Select', Fuzziness.REGEX)], 'lower', [('select', Fuzziness.REGEX)]), + ( + [('Select', Fuzziness.REGEX), ('From', Fuzziness.PERFECT)], + 'upper', + [('SELECT', Fuzziness.REGEX), ('FROM', Fuzziness.PERFECT)], + ), + ], +) +def test_apply_casing( + completions: list[tuple[str, int]], + casing: str | None, + expected: list[tuple[str, int]], +) -> None: + completer = SQLCompleter() + assert list(completer.apply_casing(completions, casing)) == expected + + +def test_find_matches_uses_last_word_for_prefix_matching() -> None: + matches = collect_matches( + 'select ord', + ['orders', 'user_orders'], + start_only=True, + fuzzy=False, + ) + + assert matches == [('orders', Fuzziness.PERFECT)] + + +def test_find_matches_supports_substring_matching() -> None: + matches = collect_matches( + 'name', + ['table_name', 'name_table'], + start_only=False, + fuzzy=False, + ) + + assert matches == [ + ('table_name', Fuzziness.PERFECT), + ('name_table', Fuzziness.PERFECT), + ] + + +def test_find_matches_quotes_identifiers_when_text_starts_with_backtick() -> None: + matches = collect_matches('`us', ['users']) + + assert matches == [('`users`', Fuzziness.REGEX)] + + +def test_find_matches_quotes_identifiers_when_cursor_is_inside_backticks() -> None: + matches = collect_matches( + 'uu', + ['users', '`uuid`'], + text_before_cursor='select `uu', + ) + + assert matches == [('`uuid`', Fuzziness.REGEX)] + + +def test_find_matches_preserves_asterisk_inside_backticks() -> None: + matches = collect_matches( + '*', + ['*'], + text_before_cursor='select `*', + ) + + assert matches == [('*', Fuzziness.REGEX)] + + +def test_find_matches_finds_regex_matches() -> None: + matches = collect_matches('sel', ['SELECT', 'foo_select_bar']) + + assert matches == [ + ('SELECT', Fuzziness.REGEX), + ('foo_select_bar', Fuzziness.REGEX), + ] + + +def test_find_matches_finds_under_word_matches() -> None: + matches = collect_matches('us_de_fu', ['user_defined_function']) + + assert matches == [('user_defined_function', Fuzziness.UNDER_WORDS)] + + +def test_find_matches_finds_camel_case_matches(monkeypatch) -> None: + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', lambda *args, **kwargs: []) + + matches = collect_matches('TiZoTrTy', ['TimeZoneTransitionType']) + + assert matches == [('TimeZoneTransitionType', Fuzziness.CAMEL_CASE)] + + +def test_find_matches_finds_rapidfuzz_matches() -> None: + matches = collect_matches('sleect', ['SELECT']) + + assert matches == [('SELECT', Fuzziness.RAPIDFUZZ)] + + +def test_find_matches_skips_rapidfuzz_for_short_text(monkeypatch) -> None: + def fail_extract(*args, **kwargs): + raise AssertionError('rapidfuzz should not be called') + + monkeypatch.setattr(mycli.sqlcompleter.rapidfuzz.process, 'extract', fail_extract) + + matches = collect_matches('sel', ['SELECT']) + + assert matches == [('SELECT', Fuzziness.REGEX)] + + +def test_find_matches_filters_short_rapidfuzz_candidates(monkeypatch) -> None: + monkeypatch.setattr( + mycli.sqlcompleter.rapidfuzz.process, + 'extract', + lambda *args, **kwargs: [('abc', 99, 0), ('alphabet', 95, 1)], + ) + + matches = collect_matches('alpahet', ['abc', 'alphabet']) + + assert matches == [('alphabet', Fuzziness.RAPIDFUZZ)] + + +@pytest.mark.parametrize( + ('orig_text', 'collection', 'casing', 'expected'), + [ + ('sel', ['SELECT'], 'auto', [('select', Fuzziness.REGEX)]), + ('SEL', ['select'], 'auto', [('SELECT', Fuzziness.REGEX)]), + ('sel', ['select'], 'upper', [('SELECT', Fuzziness.REGEX)]), + ('SEL', ['SELECT'], 'lower', [('select', Fuzziness.REGEX)]), + ], +) +def test_find_matches_applies_casing( + orig_text: str, + collection: list[str], + casing: str, + expected: list[tuple[str, int]], +) -> None: + matches = collect_matches(orig_text, collection, casing=casing) + + assert matches == expected