Skip to content

Commit

Permalink
Merge 7c7e390 into ebf0a03
Browse files Browse the repository at this point in the history
  • Loading branch information
mfussenegger committed Aug 1, 2019
2 parents ebf0a03 + 7c7e390 commit 9bd9ce3
Show file tree
Hide file tree
Showing 10 changed files with 104 additions and 105 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -2,6 +2,7 @@
.installed.cfg
.tox/
.venv/
venv/
*.DS_Store
*.egg-info
*.iml
Expand All @@ -14,3 +15,5 @@ dist/
eggs/
out/
parts/
.mypy_cache/

3 changes: 2 additions & 1 deletion .travis.yml
Expand Up @@ -14,7 +14,8 @@ before_install:
- pip install setuptools==33.1.1 # pin specific version in virtualenv
- pip freeze --all # debug
- sudo add-apt-repository ppa:openjdk-r/ppa -y
- sudo apt-get install -y openjdk-11-jre-headless
- sudo apt-get update
- sudo apt-get install -y openjdk-12-jdk

install:
- python bootstrap.py
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Expand Up @@ -7,3 +7,6 @@ exclude = tabulate.py

[pycodestyle]
ignore = E501

[pydocstyle]
ignore = D107, D203, D213
148 changes: 70 additions & 78 deletions src/crate/crash/command.py
Expand Up @@ -225,8 +225,8 @@ def __init__(self,
self.sys_info_cmd = SysInfoCommand(self)
self.commands = {
'q': self._quit,
'c': self._connect,
'connect': self._connect,
'c': self._connect_and_print_result,
'connect': self._connect_and_print_result,
'dt': self._show_tables,
'sysinfo': self.sys_info_cmd.execute,
}
Expand All @@ -249,7 +249,7 @@ def __init__(self,
# establish connection
self.cursor = None
self.connection = None
self._do_connect(crate_hosts)
self._connect(crate_hosts)

def __enter__(self):
return self
Expand Down Expand Up @@ -277,7 +277,7 @@ def pprint(self, rows, cols):
def process_iterable(self, stdin):
any_statement = False
for statement in _parse_statements(stdin):
self._exec(statement)
self._exec_and_print(statement)
any_statement = True
return any_statement

Expand All @@ -286,7 +286,7 @@ def process(self, text):
self._try_exec_cmd(text.lstrip('\\'))
else:
for statement in _parse_statements([text]):
self._exec(statement)
self._exec_and_print(statement)

def exit(self):
self.close()
Expand Down Expand Up @@ -314,11 +314,11 @@ def _show_tables(self, *args):
table_filter = \
" AND table_type = 'BASE TABLE'" if v >= TABLE_TYPE_MIN_VERSION else ""

self._exec("SELECT format('%s.%s', {schema}, table_name) AS name "
"FROM information_schema.tables "
"WHERE {schema} NOT IN ('sys','information_schema', 'pg_catalog')"
"{table_filter}"
.format(schema=schema_name, table_filter=table_filter))
self._exec_and_print(
"SELECT format('%s.%s', {schema}, table_name) AS name "
"FROM information_schema.tables "
"WHERE {schema} NOT IN ('sys','information_schema', 'pg_catalog')"
"{table_filter}".format(schema=schema_name, table_filter=table_filter))

@noargs_command
def _quit(self, *args):
Expand All @@ -330,7 +330,7 @@ def is_conn_available(self):
return self.connection and \
self.connection.lowest_server_version != StrictVersion("0.0.0")

def _do_connect(self, servers):
def _connect(self, servers):
self.last_connected_servers = servers
if self.cursor or self.connection:
self.close() # reset open cursor and connection
Expand All @@ -347,16 +347,16 @@ def _do_connect(self, servers):
self.cursor = self.connection.cursor()
self._fetch_session_info()

def _connect(self, servers):
def _connect_and_print_result(self, servers):
""" connect to the given server, e.g.: \\connect localhost:4200 """
self._do_connect(servers.split(' '))
self._verify_connection(verbose=True)
self._connect(servers.split(' '))
self._print_connect_result(verbose=True)

def reconnect(self):
"""Connect with same configuration and to last connected servers"""
self._do_connect(self.last_connected_servers)
"""Connect with same configuration and to last connected servers."""
self._connect(self.last_connected_servers)

def _verify_connection(self, verbose=False):
def _get_server_information(self):
results = []
failed = 0
client = self.connection.client
Expand All @@ -372,7 +372,10 @@ def _verify_connection(self, verbose=False):
# sort by CONNECTED DESC, SERVER_URL
results.sort(key=itemgetter(3), reverse=True)
results.sort(key=itemgetter(0))
return results, failed

def _print_connect_result(self, verbose=False):
results, failed = self._get_server_information()
if verbose:
cols = ['server_url', 'node_name', 'version', 'connected', 'message']
self.pprint(results, cols)
Expand All @@ -391,27 +394,18 @@ def _verify_connection(self, verbose=False):
def _fetch_session_info(self):
if self.is_conn_available() \
and self.connection.lowest_server_version >= StrictVersion("2.0"):
user, schema = self._user_and_schema()

try:
self.cursor.execute('SELECT current_user, current_schema')
except ProgrammingError:
# current_user is only available in the enterprise edition
self.cursor.execute('SELECT NULL, current_schema')

user, schema = self.cursor.fetchone()
self.connect_info = ConnectionMeta(user, schema)
else:
self.connect_info = ConnectionMeta(None, None)

def _user_and_schema(self):
try:
# CURRENT_USER function is only available in Enterprise Edition.
self.cursor.execute("""
SELECT
current_user AS "user",
current_schema AS "schema";
""")
except ProgrammingError:
self.cursor.execute("""
SELECT
NULL AS "user",
current_schema AS "schema";
""")
return self.cursor.fetchone()

def _try_exec_cmd(self, line):
words = line.split(' ', 1)
if not words or not words[0]:
Expand Down Expand Up @@ -445,11 +439,8 @@ def _try_exec_cmd(self, line):
'Unknown command. Type \\? for a full list of available commands.')
return False

def _exec(self, line):
success = self.execute(line)
self.exit_code = self.exit_code or int(not success)

def _execute(self, statement):
def _exec(self, statement: str) -> bool:
"""Execute the statement, prints errors if any occurr but no results."""
try:
self.cursor.execute(statement)
return True
Expand All @@ -464,8 +455,10 @@ def _execute(self, statement):
self.logger.critical('\n' + e.error_trace)
return False

def execute(self, statement):
success = self._execute(statement)
def _exec_and_print(self, statement: str) -> bool:
"""Execute the statement and print the output."""
success = self._exec(statement)
self.exit_code = self.exit_code or int(not success)
if not success:
return False
cur = self.cursor
Expand All @@ -488,9 +481,7 @@ def execute(self, statement):


def stmt_type(statement):
"""
Extract type of statement, e.g. SELECT, INSERT, UPDATE, DELETE, ...
"""
"""Extract type of statement, e.g. SELECT, INSERT, UPDATE, DELETE, ..."""
return re.findall(r'[\w]+', statement)[0].upper()


Expand Down Expand Up @@ -536,20 +527,30 @@ def get_information_schema_query(lowest_server_version):
return information_schema_query.format(schema=schema_name)


def main():
is_tty = sys.stdout.isatty()
printer = ColorPrinter(is_tty)
output_writer = OutputWriter(PrintWrapper(), is_tty)

def _load_conf(printer, formats) -> Configuration:
config = parse_config_path()
conf = None
try:
conf = Configuration(config)
return Configuration(config)
except ConfigurationError as e:
printer.warn(str(e))
parser = get_parser(output_writer.formats)
parser = get_parser(formats)
parser.print_usage()
sys.exit(1)


def _resolve_password(is_tty, force_passwd_prompt):
if force_passwd_prompt and is_tty:
return getpass()
elif not force_passwd_prompt:
return os.environ.get('CRATEPW', None)


def main():
is_tty = sys.stdout.isatty()
printer = ColorPrinter(is_tty)
output_writer = OutputWriter(PrintWrapper(), is_tty)

conf = _load_conf(printer, output_writer.formats)
parser = get_parser(output_writer.formats, conf=conf)
try:
args = parser.parse_args()
Expand All @@ -565,18 +566,7 @@ def main():
crate_hosts = [host_and_port(h) for h in args.hosts]
error_trace = args.verbose > 0

force_passwd_prompt = args.force_passwd_prompt
password = None

# If password prompt is not forced try to get it from env. variable.
if not force_passwd_prompt:
password = os.environ.get('CRATEPW', None)

# Prompt for password immediately to avoid that the first time trying to
# connect to the server runs into an `Unauthorized` excpetion
# is_tty = False
if force_passwd_prompt and not password and is_tty:
password = getpass()
password = _resolve_password(is_tty, args.force_passwd_prompt)

# Tries to create a connection to the server.
# Prompts for the password automatically if the server only accepts
Expand All @@ -586,7 +576,7 @@ def main():
cmd = _create_shell(crate_hosts, error_trace, output_writer, is_tty,
args, password=password)
except (ProgrammingError, LocationParseError) as e:
if '401' in e.message and not force_passwd_prompt:
if '401' in e.message and not args.force_passwd_prompt:
if is_tty:
password = getpass()
try:
Expand All @@ -601,27 +591,29 @@ def main():
printer.warn(str(e))
sys.exit(1)

cmd._verify_connection(verbose=error_trace)
cmd._print_connect_result(verbose=error_trace)
if not cmd.is_conn_available():
sys.exit(1)

done = False
stdin_data = get_stdin()
def save_and_exit():
conf.save()
sys.exit(cmd.exit())

if args.sysinfo:
cmd.output_writer.output_format = 'mixed'
cmd.sys_info_cmd.execute()
done = True
save_and_exit()

if args.command:
cmd.process(args.command)
done = True
elif stdin_data:
if cmd.process_iterable(stdin_data):
done = True
if not done:
from .repl import loop
loop(cmd, args.history)
conf.save()
sys.exit(cmd.exit())
save_and_exit()

if cmd.process_iterable(get_stdin()):
save_and_exit()

from .repl import loop
loop(cmd, args.history)
save_and_exit()


def _create_shell(crate_hosts, error_trace, output_writer, is_tty, args,
Expand Down
2 changes: 1 addition & 1 deletion src/crate/crash/commands.py
Expand Up @@ -132,7 +132,7 @@ class CheckBaseCommand(Command):
check_name = None

def execute(self, cmd, stmt):
success = cmd._execute(stmt)
success = cmd._exec(stmt)
cmd.exit_code = cmd.exit_code or int(not success)
if not success:
return False
Expand Down
6 changes: 3 additions & 3 deletions src/crate/crash/sysinfo.py
Expand Up @@ -29,7 +29,7 @@
SYSINFO_MIN_VERSION = StrictVersion("0.54.0")


class SysInfoCommand(object):
class SysInfoCommand:

CLUSTER_INFO = {
'shards_query': """
Expand Down Expand Up @@ -99,7 +99,7 @@ def _cluster_info(self, result):
cols = []

for query in SysInfoCommand.CLUSTER_INFO:
success = self.cmd._execute(SysInfoCommand.CLUSTER_INFO[query])
success = self.cmd._exec(SysInfoCommand.CLUSTER_INFO[query])
if success is False:
return success
rows.extend(self.cmd.cursor.fetchall()[0])
Expand All @@ -108,7 +108,7 @@ def _cluster_info(self, result):
return True

def _nodes_info(self, result):
success = self.cmd._execute(SysInfoCommand.NODES_INFO[0])
success = self.cmd._exec(SysInfoCommand.NODES_INFO[0])
if success:
result.append(Result(self.cmd.cursor.fetchall(),
[c[0] for c in self.cmd.cursor.description]))
Expand Down
10 changes: 5 additions & 5 deletions src/crate/crash/test_command.py
Expand Up @@ -566,7 +566,7 @@ def test_verbose_with_error_trace(self):
cmd.logger = Mock()
cmd.cursor.execute = Mock(side_effect=ProgrammingError(msg="the error message",
error_trace="error trace"))
cmd.execute("select invalid statement")
cmd._exec_and_print("select invalid statement")
cmd.logger.critical.assert_any_call("the error message")
cmd.logger.critical.assert_called_with("\nerror trace")

Expand All @@ -575,7 +575,7 @@ def test_verbose_no_error_trace(self):
cmd.logger = Mock()
cmd.cursor.execute = Mock(side_effect=ProgrammingError(msg="the error message",
error_trace=None))
cmd.execute("select invalid statement")
cmd._exec_and_print("select invalid statement")
# only the message is logged
cmd.logger.critical.assert_called_once_with("the error message")

Expand Down Expand Up @@ -681,7 +681,7 @@ def test_wrong_host_format(self):

def test_command_timeout(self):
with CrateShell(self.crate_host) as crash:
crash.execute("""
crash.process("""
CREATE FUNCTION fib(long)
RETURNS LONG
LANGUAGE javascript AS '
Expand All @@ -699,15 +699,15 @@ def test_command_timeout(self):
error_trace=False,
timeout=timeout) as crash:
crash.logger = Mock()
crash.execute(slow_query)
crash.process(slow_query)
crash.logger.warn.assert_any_call("Use \\connect <server> to connect to one or more servers first.")

# with verbose
with CrateShell(self.crate_host,
error_trace=True,
timeout=timeout) as crash:
crash.logger = Mock()
crash.execute(slow_query)
crash.process(slow_query)
crash.logger.warn.assert_any_call("No more Servers available, exception from last server: HTTPConnectionPool(host='127.0.0.1', port=44209): Read timed out. (read timeout=0.1)")
crash.logger.warn.assert_any_call("Use \\connect <server> to connect to one or more servers first.")

Expand Down

0 comments on commit 9bd9ce3

Please sign in to comment.