Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suggest objects from schemas not in search_path #649

Merged
merged 4 commits into from
Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pgcli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(self, force_passwd_prompt=False, never_passwd_prompt=False,
'generate_aliases': c['main'].as_bool('generate_aliases'),
'asterisk_column_order': c['main']['asterisk_column_order'],
'qualify_columns': c['main']['qualify_columns'],
'search_path_filter': c['main'].as_bool('search_path_filter'),
'single_connection': single_connection,
'keyword_casing': keyword_casing,
}
Expand Down
3 changes: 3 additions & 0 deletions pgcli/pgclirc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ asterisk_column_order = table_order
# Possible values: "always", never" and "if_more_than_one_table"
qualify_columns = if_more_than_one_table

# When no schema is entered, only suggest objects in search_path
search_path_filter = False

# Default pager.
# By default 'PAGER' environment variable is used
# pager = less -SRXF
Expand Down
86 changes: 47 additions & 39 deletions pgcli/pgcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
NamedQueries.instance = NamedQueries.from_config(
load_config(config_location() + 'config'))


Match = namedtuple('Match', ['completion', 'priority'])
_SchemaObject = namedtuple('SchemaObject', ['name', 'schema', 'function'])
def SchemaObject(name, schema=None, function=False):
return _SchemaObject(name, schema, function)

_Candidate = namedtuple('Candidate', ['completion', 'priority', 'meta', 'synonyms'])
def Candidate(completion, priority=None, meta=None, synonyms=None):
Expand Down Expand Up @@ -57,6 +59,7 @@ def __init__(self, smart_completion=True, pgspecial=None, settings=None):
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
settings = settings or {}
self.search_path_filter = settings.get('search_path_filter')
self.generate_aliases = settings.get('generate_aliases')
self.casing_file = settings.get('casing_file')
self.generate_casing_file = settings.get('generate_casing_file')
Expand Down Expand Up @@ -550,8 +553,8 @@ def list_dict(pairs): # Turns [(a, b), (a, c)] into {a: [b, c]}
return self.find_matches(word_before_cursor, conds, meta='join')

def get_function_matches(self, suggestion, word_before_cursor, alias=False):
def _cand(func_name, alias):
return self._make_cand(func_name, alias, suggestion, function=True)
def _cand(func, alias):
return self._make_cand(func, alias, suggestion)
if suggestion.filter == 'for_from_clause':
# Only suggest functions allowed in FROM clause
filt = lambda f: not f.is_aggregate and not f.is_window
Expand Down Expand Up @@ -597,24 +600,26 @@ def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+ self.get_view_matches(v_sug, word_before_cursor, alias)
+ self.get_function_matches(f_sug, word_before_cursor, alias))

def _make_cand(self, tbl, do_alias, suggestion, function=False):
cased_tbl = self.case(tbl)
alias = self.alias(cased_tbl, suggestion.table_refs)
def _make_cand(self, tbl, do_alias, suggestion):
cased_tbl = self.case(tbl.name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here that tbl is a SchemaObject?

if do_alias:
alias = self.alias(cased_tbl, suggestion.table_refs)
synonyms = (cased_tbl, generate_alias(cased_tbl))
maybe_parens = '()' if function else ''
maybe_parens = '()' if tbl.function else ''
maybe_alias = (' ' + alias) if do_alias else ''
item = cased_tbl + maybe_parens + maybe_alias
maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else ''
item = maybe_schema + cased_tbl + maybe_parens + maybe_alias
return Candidate(item, synonyms=synonyms)

def get_table_matches(self, suggestion, word_before_cursor, alias=False):
tables = self.populate_schema_objects(suggestion.schema, 'tables')
tables.extend(tbl.name for tbl in suggestion.local_tables)
tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)

# Unless we're sure the user really wants them, don't suggest the
# pg_catalog tables that are implicitly on the search path
if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
tables = [t for t in tables if not t.startswith('pg_')]
tables = [t for t in tables if not t.name.startswith('pg_')]
tables = [self._make_cand(t, alias, suggestion) for t in tables]
return self.find_matches(word_before_cursor, tables, meta='table')

Expand All @@ -624,7 +629,7 @@ def get_view_matches(self, suggestion, word_before_cursor, alias=False):

if not suggestion.schema and (
not word_before_cursor.startswith('pg_')):
views = [v for v in views if not v.startswith('pg_')]
views = [v for v in views if not v.name.startswith('pg_')]
views = [self._make_cand(v, alias, suggestion) for v in views]
return self.find_matches(word_before_cursor, views, meta='view')

Expand Down Expand Up @@ -672,6 +677,7 @@ def get_special_matches(self, _, word_before_cursor):
def get_datatype_matches(self, suggestion, word_before_cursor):
# suggest custom datatypes
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
types = [self._make_cand(t, False, suggestion) for t in types]
matches = self.find_matches(word_before_cursor, types, meta='datatype')

if not suggestion.schema:
Expand Down Expand Up @@ -747,22 +753,28 @@ def addcols(schema, rel, alias, reltype, cols):

return columns

def _get_schemas(self, obj_typ, schema):
metadata = self.dbmetadata[obj_typ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A docstring here would be nice. like "Return a list of schemas from which SchemaObjects can be suggested", and note that schema is the optional user-specified schema qualification?

if schema:
schema = self.escape_name(schema)
return [schema] if schema in metadata else []
return self.search_path if self.search_path_filter else metadata.keys()

def _maybe_schema(self, schema, parent):
return None if parent or schema in self.search_path else schema

def populate_schema_objects(self, schema, obj_type):
"""Returns list of tables or functions for a (optional) schema"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be changed to Returns a list of SchemaObjects...


metadata = self.dbmetadata[obj_type]
if schema:
try:
objects = metadata[self.escape_name(schema)].keys()
except KeyError:
# schema doesn't exist
objects = []
else:
schemas = self.search_path
objects = [obj for schema in schemas
for obj in metadata[schema].keys()]

return [self.case(o) for o in objects]
return [
SchemaObject(
name=obj,
schema=(self._maybe_schema(schema=sch, parent=schema)),
function=(obj_type == 'functions')
)
for sch in self._get_schemas(obj_type, schema)
for obj in self.dbmetadata[obj_type][sch].keys()
]

def populate_functions(self, schema, filter_func):
"""Returns a list of function names
Expand All @@ -772,24 +784,20 @@ def populate_functions(self, schema, filter_func):
kept or discarded
"""

metadata = self.dbmetadata['functions']

# Because of multiple dispatch, we can have multiple functions
# with the same name, which is why `for meta in metas` is necessary
# in the comprehensions below
if schema:
schema = self.escape_name(schema)
try:
return [func for (func, metas) in metadata[schema].items()
for meta in metas
if filter_func(meta)]
except KeyError:
return []
else:
return [func for schema in self.search_path
for (func, metas) in metadata[schema].items()
for meta in metas
if filter_func(meta)]
return [
SchemaObject(
name=func,
schema=(self._maybe_schema(schema=sch, parent=schema)),
function=True
)
for sch in self._get_schemas('functions', schema)
for (func, metas) in self.dbmetadata['functions'][sch].items()
for meta in metas
if filter_func(meta)
]



76 changes: 71 additions & 5 deletions tests/test_smart_completion_multiple_schemata.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,43 @@

@pytest.fixture
def completer():
return testdata.completer
return testdata.get_completer(settings={'search_path_filter': True})

casing = ('SELECT', 'Orders', 'User_Emails', 'CUSTOM', 'Func1', 'Entries',
'Tags', 'EntryTags', 'EntAccLog',
'EntryID', 'EntryTitle', 'EntryText')

@pytest.fixture
def completer_with_casing():
return testdata.get_completer(casing=casing)
return testdata.get_completer(
settings={'search_path_filter': True},
casing=casing
)

@pytest.fixture
def completer_with_aliases():
return testdata.get_completer({'generate_aliases': True})
return testdata.get_completer(
settings={'generate_aliases': True, 'search_path_filter': True}
)

@pytest.fixture
def completer_aliases_casing():
return testdata.get_completer(
settings={'generate_aliases': True, 'search_path_filter': True},
casing=casing
)

@pytest.fixture
def completer_aliases_casing(request):
return testdata.get_completer({'generate_aliases': True}, casing)
def completer_all_schemas():
return testdata.get_completer()

@pytest.fixture
def completer_all_schemas_casing():
return testdata.get_completer(casing=casing)

@pytest.fixture
def completer_all_schemas_aliases():
return testdata.get_completer(settings={'generate_aliases': True})

@pytest.fixture
def complete_event():
Expand Down Expand Up @@ -598,3 +618,49 @@ def test_column_alias_search_qualified(completer_aliases_casing,
Document(text, cursor_position=len('SELECT E.ei')), complete_event)
cols = ('EntryID', 'EntryTitle')
assert result[:3] == [column(c, -2) for c in cols]

def test_all_schema_objects(completer_all_schemas, complete_event):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('orders', '"select"', 'custom.shipments')]
+ [function(x+'()') for x in ('func2', 'custom.func3')]
)

def test_all_schema_objects_with_casing(
completer_all_schemas_casing, complete_event
):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas_casing.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('Orders', '"select"', 'CUSTOM.shipments')]
+ [function(x+'()') for x in ('func2', 'CUSTOM.func3')]
)

def test_all_schema_objects_with_aliases(
completer_all_schemas_aliases, complete_event
):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(
completer_all_schemas_aliases.get_completions(
Document(text=text, cursor_position=position),
complete_event
)
)
assert result >= set(
[table(x) for x in ('orders o', '"select" s', 'custom.shipments s')]
+ [function(x) for x in ('func2() f', 'custom.func3() f')]
)