Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest objects from schemas not in search_path #649

Merged
merged 4 commits into from
Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pgcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
'generate_aliases': c['main'].as_bool('generate_aliases'),
'asterisk_column_order': c['main']['asterisk_column_order'],
'qualify_columns': c['main']['qualify_columns'],
'search_path_filter': c['main'].as_bool('search_path_filter'),
'single_connection': single_connection,
'keyword_casing': keyword_casing,
}
Expand Down
3 changes: 3 additions & 0 deletions pgcli/pgclirc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ asterisk_column_order = table_order
# Possible values: "always", never" and "if_more_than_one_table"
qualify_columns = if_more_than_one_table

# When no schema is entered, only suggest objects in search_path
search_path_filter = False

# Default pager.
# By default 'PAGER' environment variable is used
# pager = less -SRXF
Expand Down
114 changes: 67 additions & 47 deletions pgcli/pgcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@
NamedQueries.instance = NamedQueries.from_config(
load_config(config_location() + 'config'))


Match = namedtuple('Match', ['completion', 'priority'])
_SchemaObject = namedtuple('SchemaObject', ['name', 'schema', 'function'])
def SchemaObject(name, schema=None, function=False):
return _SchemaObject(name, schema, function)

_Candidate = namedtuple('Candidate', ['completion', 'priority', 'meta', 'synonyms'])
def Candidate(completion, priority=None, meta=None, synonyms=None):
return _Candidate(completion, priority, meta, synonyms or [completion])
_Candidate = namedtuple(
'Candidate', ['completion', 'prio', 'meta', 'synonyms', 'prio2']
)
def Candidate(completion, prio=None, meta=None, synonyms=None, prio2=None):
return _Candidate(completion, prio, meta, synonyms or [completion], prio2)

normalize_ref = lambda ref: ref if ref[0] == '"' else '"' + ref.lower() + '"'

Expand All @@ -57,6 +61,7 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
settings = settings or {}
self.search_path_filter = settings.get('search_path_filter')
self.generate_aliases = settings.get('generate_aliases')
self.casing_file = settings.get('casing_file')
self.generate_casing_file = settings.get('generate_casing_file')
Expand Down Expand Up @@ -315,15 +320,15 @@ def _match(item):
matches = []
for cand in collection:
if isinstance(cand, _Candidate):
item, prio, display_meta, synonyms = cand
item, prio, display_meta, synonyms, prio2 = cand
if display_meta is None:
display_meta = meta
syn_matches = (_match(x) for x in synonyms)
# Nones need to be removed to avoid max() crashing in Python 3
syn_matches = [m for m in syn_matches if m]
sort_key = max(syn_matches) if syn_matches else None
else:
item, display_meta, prio = cand, meta, 0
item, display_meta, prio, prio2 = cand, meta, 0, 0
sort_key = _match(cand)

if sort_key:
Expand All @@ -345,7 +350,10 @@ def _match(item):
+ tuple(c for c in item))

item = self.case(item)
priority = sort_key, type_priority, prio, priority_func(item), lexical_priority
priority = (
sort_key, type_priority, prio, priority_func(item),
prio2, lexical_priority
)

matches.append(Match(
completion=Completion(item, -text_len,
Expand Down Expand Up @@ -550,8 +558,8 @@ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
return self.find_matches(word_before_cursor, conds, meta='join')

def get_function_matches(self, suggestion, word_before_cursor, alias=False):
def _cand(func_name, alias):
return self._make_cand(func_name, alias, suggestion, function=True)
def _cand(func, alias):
return self._make_cand(func, alias, suggestion)
if suggestion.filter == 'for_from_clause':
# Only suggest functions allowed in FROM clause
filt = lambda f: not f.is_aggregate and not f.is_window
Expand Down Expand Up @@ -597,24 +605,28 @@ def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+ self.get_view_matches(v_sug, word_before_cursor, alias)
+ self.get_function_matches(f_sug, word_before_cursor, alias))

def _make_cand(self, tbl, do_alias, suggestion, function=False):
cased_tbl = self.case(tbl)
alias = self.alias(cased_tbl, suggestion.table_refs)
# Note: tbl is a SchemaObject
def _make_cand(self, tbl, do_alias, suggestion):
cased_tbl = self.case(tbl.name)
if do_alias:
alias = self.alias(cased_tbl, suggestion.table_refs)
synonyms = (cased_tbl, generate_alias(cased_tbl))
maybe_parens = '()' if function else ''
maybe_parens = '()' if tbl.function else ''
maybe_alias = (' ' + alias) if do_alias else ''
item = cased_tbl + maybe_parens + maybe_alias
return Candidate(item, synonyms=synonyms)
maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else ''
item = maybe_schema + cased_tbl + maybe_parens + maybe_alias
prio2 = 0 if tbl.schema else 1
return Candidate(item, synonyms=synonyms, prio2=prio2)

def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
tables.extend(tbl.name for tbl in suggestion.local_tables)
tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)

# Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.startswith('pg_')]
tables = [t for t in tables if not t.name.startswith('pg_')]
tables = [self._make_cand(t, alias, suggestion) for t in tables]
return self.find_matches(word_before_cursor, tables, meta='table')

Expand All @@ -624,7 +636,7 @@ def get_view_matches(self, suggestion, word_before_cursor, alias=False):

if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.startswith('pg_')]
views = [v for v in views if not v.name.startswith('pg_')]
views = [self._make_cand(v, alias, suggestion) for v in views]
return self.find_matches(word_before_cursor, views, meta='view')

Expand Down Expand Up @@ -672,6 +684,7 @@ def get_special_matches(self, _, word_before_cursor):
def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types, meta='datatype')

if not suggestion.schema:
Expand Down Expand Up @@ -747,22 +760,33 @@ def addcols(schema, rel, alias, reltype, cols):

return columns

def populate_schema_objects(self, schema, obj_type):
"""Returns list of tables or functions for a (optional) schema"""

metadata = self.dbmetadata[obj_type]
def _get_schemas(self, obj_typ, schema):
""" Returns a list of schemas from which to suggest objects
schema is the schema qualification input by the user (if any)
"""
metadata = self.dbmetadata[obj_typ]
if schema:
try:
objects = metadata[self.escape_name(schema)].keys()
except KeyError:
# schema doesn't exist
objects = []
else:
schemas = self.search_path
objects = [obj for schema in schemas
for obj in metadata[schema].keys()]
schema = self.escape_name(schema)
return [schema] if schema in metadata else []
return self.search_path if self.search_path_filter else metadata.keys()

def _maybe_schema(self, schema, parent):
return None if parent or schema in self.search_path else schema

return [self.case(o) for o in objects]
def populate_schema_objects(self, schema, obj_type):
"""Returns a list of SchemaObjects representing tables, views, funcs
schema is the schema qualification input by the user (if any)
"""

return [
SchemaObject(
name=obj,
schema=(self._maybe_schema(schema=sch, parent=schema)),
function=(obj_type == 'functions')
)
for sch in self._get_schemas(obj_type, schema)
for obj in self.dbmetadata[obj_type][sch].keys()
]

def populate_functions(self, schema, filter_func):
"""Returns a list of function names
Expand All @@ -772,24 +796,20 @@ def populate_functions(self, schema, filter_func):
kept or discarded
"""

metadata = self.dbmetadata['functions']

# Because of multiple dispatch, we can have multiple functions
# with the same name, which is why `for meta in metas` is necessary
# in the comprehensions below
if schema:
schema = self.escape_name(schema)
try:
return [func for (func, metas) in metadata[schema].items()
for meta in metas
if filter_func(meta)]
except KeyError:
return []
else:
return [func for schema in self.search_path
for (func, metas) in metadata[schema].items()
for meta in metas
if filter_func(meta)]
return [
SchemaObject(
name=func,
schema=(self._maybe_schema(schema=sch, parent=schema)),
function=True
)
for sch in self._get_schemas('functions', schema)
for (func, metas) in self.dbmetadata['functions'][sch].items()
for meta in metas
if filter_func(meta)
]



87 changes: 82 additions & 5 deletions tests/test_smart_completion_multiple_schemata.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,43 @@

@pytest.fixture
def completer():
return testdata.completer
return testdata.get_completer(settings={'search_path_filter': True})

casing = ('SELECT', 'Orders', 'User_Emails', 'CUSTOM', 'Func1', 'Entries',
'Tags', 'EntryTags', 'EntAccLog',
'EntryID', 'EntryTitle', 'EntryText')

@pytest.fixture
def completer_with_casing():
return testdata.get_completer(casing=casing)
return testdata.get_completer(
settings={'search_path_filter': True},
casing=casing
)

@pytest.fixture
def completer_with_aliases():
return testdata.get_completer({'generate_aliases': True})
return testdata.get_completer(
settings={'generate_aliases': True, 'search_path_filter': True}
)

@pytest.fixture
def completer_aliases_casing():
return testdata.get_completer(
settings={'generate_aliases': True, 'search_path_filter': True},
casing=casing
)

@pytest.fixture
def completer_all_schemas():
return testdata.get_completer()

@pytest.fixture
def completer_aliases_casing(request):
return testdata.get_completer({'generate_aliases': True}, casing)
def completer_all_schemas_casing():
return testdata.get_completer(casing=casing)

@pytest.fixture
def completer_all_schemas_aliases():
return testdata.get_completer(settings={'generate_aliases': True})

@pytest.fixture
def complete_event():
Expand Down Expand Up @@ -598,3 +618,60 @@ def test_column_alias_search_qualified(completer_aliases_casing,
Document(text, cursor_position=len('SELECT E.ei')), complete_event)
cols = ('EntryID', 'EntryTitle')
assert result[:3] == [column(c, -2) for c in cols]

def test_schema_object_order(completer_all_schemas, complete_event):
text = 'SELECT * FROM u'
position = len('SELECT * FROM u')
result = completer_all_schemas.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
assert result[:3] == [
table(t, pos=-1) for t in ('users', 'custom."Users"', 'custom.users')
]

def test_all_schema_objects(completer_all_schemas, complete_event):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('orders', '"select"', 'custom.shipments')]
+ [function(x+'()') for x in ('func2', 'custom.func3')]
)

def test_all_schema_objects_with_casing(
completer_all_schemas_casing, complete_event
):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas_casing.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('Orders', '"select"', 'CUSTOM.shipments')]
+ [function(x+'()') for x in ('func2', 'CUSTOM.func3')]
)

def test_all_schema_objects_with_aliases(
completer_all_schemas_aliases, complete_event
):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas_aliases.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('orders o', '"select" s', 'custom.shipments s')]
+ [function(x) for x in ('func2() f', 'custom.func3() f')]
)