diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 413b7495..0a144725 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,13 +7,20 @@ on: jobs: linux: - - runs-on: ubuntu-latest - strategy: matrix: python-version: [3.6, 3.7, 3.8, 3.9] + include: + - python-version: 3.6 + os: ubuntu-16.04 # MySQL 5.7.32 + - python-version: 3.7 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.8 + os: ubuntu-18.04 # MySQL 5.7.32 + - python-version: 3.9 + os: ubuntu-20.04 # MySQL 8.0.22 + runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2 @@ -42,6 +49,7 @@ jobs: - name: Pytest / behave env: PYTEST_PASSWORD: root + PYTEST_HOST: 127.0.0.1 run: | ./setup.py test --pytest-args="--cov-report= --cov=mycli" diff --git a/README.md b/README.md index c709eb89..46b5fd3f 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@ # mycli -[![Build Status](https://travis-ci.org/dbcli/mycli.svg?branch=master)](https://travis-ci.org/dbcli/mycli) +[![Build Status](https://github.com/dbcli/mycli/workflows/mycli/badge.svg)](https://github.com/dbcli/mycli/actions?query=workflow%3Amycli) [![PyPI](https://img.shields.io/pypi/v/mycli.svg?style=plastic)](https://pypi.python.org/pypi/mycli) -[![Join the chat at https://gitter.im/dbcli/mycli](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/dbcli/mycli?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) A command line client for MySQL that can do auto-completion and syntax highlighting. @@ -53,6 +52,7 @@ $ sudo apt-get install mycli # Only on debian or ubuntu -h, --host TEXT Host address of the database. -P, --port INTEGER Port number to use for connection. Honors $MYSQL_TCP_PORT. + -u, --user TEXT User name to connect to the database. -S, --socket TEXT The socket file to use for connection. -p, --password TEXT Password to connect to the database. @@ -63,8 +63,11 @@ $ sudo apt-get install mycli # Only on debian or ubuntu --ssh-password TEXT Password to connect to ssh server. --ssh-key-filename TEXT Private key filename (identify file) for the ssh connection. + --ssh-config-path TEXT Path to ssh configuration. - --ssh-config-host TEXT Host for ssh server in ssh configurations (requires paramiko). + --ssh-config-host TEXT Host to connect to ssh server reading from ssh + configuration. + --ssl-ca PATH CA file in PEM format. --ssl-capath TEXT CA directory. --ssl-cert PATH X509 cert in PEM format. @@ -73,33 +76,43 @@ $ sudo apt-get install mycli # Only on debian or ubuntu --ssl-verify-server-cert Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default. + -V, --version Output mycli's version. -v, --verbose Verbose output. -D, --database TEXT Database to use. -d, --dsn TEXT Use DSN configured into the [alias_dsn] section of myclirc file. + --list-dsn list of DSN configured into the [alias_dsn] section of myclirc file. - --list-ssh-config list ssh configurations in the ssh config (requires paramiko). + + --list-ssh-config list ssh configurations in the ssh config + (requires paramiko). + -R, --prompt TEXT Prompt format (Default: "\t \u@\h:\d> "). -l, --logfile FILENAME Log every query and its results to a file. --defaults-group-suffix TEXT Read MySQL config groups with the specified suffix. + --defaults-file PATH Only read MySQL options from the given file. --myclirc PATH Location of myclirc file. --auto-vertical-output Automatically switch to vertical output mode if the result is wider than the terminal width. + -t, --table Display batch output in table format. --csv Display batch output in CSV format. --warn / --no-warn Warn before running a destructive query. --local-infile BOOLEAN Enable/disable LOAD DATA LOCAL INFILE. - --login-path TEXT Read this path from the login file. + -g, --login-path TEXT Read this path from the login file. -e, --execute TEXT Execute command and quit. --init-command TEXT SQL statement to execute after connecting. --charset TEXT Character set for MySQL session. + --password-file PATH File or FIFO path containing the password + to connect to the db if not specified otherwise --help Show this message and exit. + Features -------- diff --git a/changelog.md b/changelog.md index 71e6220b..21f9464b 100644 --- a/changelog.md +++ b/changelog.md @@ -1,15 +1,27 @@ TBD -======= +=== Bug Fixes: ---------- * Allow `FileNotFound` exception for SSH config files. +* Fix startup error on MySQL < 5.0.22 +* Check error code rather than message for Access Denied error +* Fix login with ~/.my.cnf files Features: --------- * Add `-g` shortcut to option `--login-path`. +* Alt-Enter dispatches the command in multi-line mode. +* Allow to pass a file or FIFO path with --password-file when password is not specified or is failing (as suggested in this best-practice https://www.netmeister.org/blog/passing-passwords.html) * Reuse the same SSH connection for both main thread and completion thread. +Internal: +--------- +* Remove unused function is_open_quote() +* Use importlib, instead of file links, to locate resources +* Test various host-port combinations in command line arguments +* Switched from Cryptography to pyaes for decrypting mylogin.cnf + 1.23.2 === diff --git a/mycli/AUTHORS b/mycli/AUTHORS index c871f510..8cdea919 100644 --- a/mycli/AUTHORS +++ b/mycli/AUTHORS @@ -84,6 +84,7 @@ Contributors: * xeron * 0xflotus * Seamile + * Jerome Provensal Creator: -------- diff --git a/mycli/clibuffer.py b/mycli/clibuffer.py index c0cb5c1b..81353b63 100644 --- a/mycli/clibuffer.py +++ b/mycli/clibuffer.py @@ -1,7 +1,6 @@ from prompt_toolkit.enums import DEFAULT_BUFFER from prompt_toolkit.filters import Condition from prompt_toolkit.application import get_app -from .packages.parseutils import is_open_quote from .packages import special diff --git a/mycli/config.py b/mycli/config.py index 9c592fb5..5d711093 100644 --- a/mycli/config.py +++ b/mycli/config.py @@ -1,5 +1,3 @@ -import io -import shutil from copy import copy from io import BytesIO, TextIOWrapper import logging @@ -7,11 +5,16 @@ from os.path import exists import struct import sys -from typing import Union +from typing import Union, IO from configobj import ConfigObj, ConfigObjError -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.backends import default_backend +import pyaes + +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources try: basestring @@ -49,9 +52,9 @@ def read_config_file(f, list_values=True): config = ConfigObj(f, interpolation=False, encoding='utf8', list_values=list_values) except ConfigObjError as e: - log(logger, logging.ERROR, "Unable to parse line {0} of config file " + log(logger, logging.WARNING, "Unable to parse line {0} of config file " "'{1}'.".format(e.line_number, f)) - log(logger, logging.ERROR, "Using successfully parsed config values.") + log(logger, logging.WARNING, "Using successfully parsed config values.") return e.config except (IOError, OSError) as e: log(logger, logging.WARNING, "You don't have permission to read " @@ -61,7 +64,7 @@ def read_config_file(f, list_values=True): return config -def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: +def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list: """Get a list of configuration files that are included into config_path with !includedir directive. @@ -95,7 +98,7 @@ def get_included_configs(config_file: Union[str, io.TextIOWrapper]) -> list: def read_config_files(files, list_values=True): """Read and merge a list of config files.""" - config = ConfigObj(list_values=list_values) + config = create_default_config(list_values=list_values) _files = copy(files) while _files: _file = _files.pop(0) @@ -112,12 +115,21 @@ def read_config_files(files, list_values=True): return config -def write_default_config(source, destination, overwrite=False): +def create_default_config(list_values=True): + import mycli + default_config_file = resources.open_text(mycli, 'myclirc') + return read_config_file(default_config_file, list_values=list_values) + + +def write_default_config(destination, overwrite=False): + import mycli + default_config = resources.read_text(mycli, 'myclirc') destination = os.path.expanduser(destination) if not overwrite and exists(destination): return - shutil.copyfile(source, destination) + with open(destination, 'w') as f: + f.write(default_config) def get_mylogin_cnf_path(): @@ -160,6 +172,58 @@ def open_mylogin_cnf(name): return TextIOWrapper(plaintext) +# TODO reuse code between encryption an decryption +def encrypt_mylogin_cnf(plaintext: IO[str]): + """Encryption of .mylogin.cnf file, analogous to calling + mysql_config_editor. + + Code is based on the python implementation by Kristian Koehntopp + https://github.com/isotopp/mysql-config-coder + + """ + def realkey(key): + """Create the AES key from the login key.""" + rkey = bytearray(16) + for i in range(len(key)): + rkey[i % 16] ^= key[i] + return bytes(rkey) + + def encode_line(plaintext, real_key, buf_len): + aes = pyaes.AESModeOfOperationECB(real_key) + text_len = len(plaintext) + pad_len = buf_len - text_len + pad_chr = bytes(chr(pad_len), "utf8") + plaintext = plaintext.encode() + pad_chr * pad_len + encrypted_text = b''.join( + [aes.encrypt(plaintext[i: i + 16]) + for i in range(0, len(plaintext), 16)] + ) + return encrypted_text + + LOGIN_KEY_LENGTH = 20 + key = os.urandom(LOGIN_KEY_LENGTH) + real_key = realkey(key) + + outfile = BytesIO() + + outfile.write(struct.pack("i", 0)) + outfile.write(key) + + while True: + line = plaintext.readline() + if not line: + break + real_len = len(line) + pad_len = (int(real_len / 16) + 1) * 16 + + outfile.write(struct.pack("i", pad_len)) + x = encode_line(line, real_key, pad_len) + outfile.write(x) + + outfile.seek(0) + return outfile + + def read_and_decrypt_mylogin_cnf(f): """Read and decrypt the contents of .mylogin.cnf. @@ -201,11 +265,9 @@ def read_and_decrypt_mylogin_cnf(f): return None rkey = struct.pack('16B', *rkey) - # Create a decryptor object using the key. - decryptor = _get_decryptor(rkey) - # Create a bytes buffer to hold the plaintext. plaintext = BytesIO() + aes = pyaes.AESModeOfOperationECB(rkey) while True: # Read the length of the ciphertext. @@ -216,7 +278,10 @@ def read_and_decrypt_mylogin_cnf(f): # Read cipher_len bytes from the file and decrypt. cipher = f.read(cipher_len) - plain = _remove_pad(decryptor.update(cipher)) + plain = _remove_pad( + b''.join([aes.decrypt(cipher[i: i + 16]) + for i in range(0, cipher_len, 16)]) + ) if plain is False: continue plaintext.write(plain) @@ -260,15 +325,8 @@ def strip_matching_quotes(s): return s -def _get_decryptor(key): - """Get the AES decryptor.""" - c = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend()) - return c.decryptor() - - def _remove_pad(line): """Remove the pad from the *line*.""" - pad_length = ord(line[-1:]) try: # Determine pad length. pad_length = ord(line[-1:]) diff --git a/mycli/key_bindings.py b/mycli/key_bindings.py index 57b917bf..4a24c82b 100644 --- a/mycli/key_bindings.py +++ b/mycli/key_bindings.py @@ -78,8 +78,12 @@ def _(event): @kb.add('escape', 'enter') def _(event): - """Introduces a line break regardless of multi-line mode or not.""" + """Introduces a line break in multi-line mode, or dispatches the + command in single-line mode.""" _logger.debug('Detected alt-enter key.') - event.app.current_buffer.insert_text('\n') + if mycli.multi_line: + event.app.current_buffer.validate_and_handle() + else: + event.app.current_buffer.insert_text('\n') return kb diff --git a/mycli/main.py b/mycli/main.py index 23f50913..eefd1cf8 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,9 +1,12 @@ +from collections import defaultdict +from io import open import os import sys import traceback import logging import threading import re +import stat import fileinput from collections import namedtuple try: @@ -13,7 +16,6 @@ from time import time from datetime import datetime from random import choice -from io import open from pymysql import OperationalError from cli_helpers.tabular_output import TabularOutputFormatter @@ -44,7 +46,7 @@ from .sqlcompleter import SQLCompleter from .clitoolbar import create_toolbar_tokens_func from .clistyle import style_factory, style_factory_output -from .sqlexecute import FIELD_TYPES, SQLExecute +from .sqlexecute import FIELD_TYPES, SQLExecute, ERROR_CODE_ACCESS_DENIED from .clibuffer import cli_is_multiline from .completion_refresher import CompletionRefresher from .config import (write_default_config, get_mylogin_cnf_path, @@ -52,7 +54,7 @@ strip_matching_quotes) from .key_bindings import mycli_bindings from .lexer import MyCliLexer -from .__init__ import __version__ +from . import __version__ from .compat import WIN from .packages.filepaths import dir_path_exists, guess_socket_location @@ -67,11 +69,19 @@ from urllib.parse import urlparse from urllib.parse import unquote +try: + import importlib.resources as resources +except ImportError: + # Python < 3.7 + import importlib_resources as resources # Query tuples are used for maintaining history Query = namedtuple('Query', ['query', 'successful', 'mutating']) -PACKAGE_ROOT = os.path.abspath(os.path.dirname(__file__)) +SUPPORT_INFO = ( + 'Home: http://mycli.net\n' + 'Bug tracker: https://github.com/dbcli/mycli/issues' +) class MyCli(object): @@ -98,7 +108,6 @@ class MyCli(object): os.path.join(os.path.expanduser(xdg_config_home), "mycli", "myclirc") ] - default_config_file = os.path.join(PACKAGE_ROOT, 'myclirc') pwd_config_file = os.path.join(os.getcwd(), ".myclirc") def __init__(self, sqlexecute=None, prompt=None, @@ -118,7 +127,7 @@ def __init__(self, sqlexecute=None, prompt=None, self.cnf_files = [defaults_file] # Load config. - config_files = ([self.default_config_file] + self.system_config_files + + config_files = (self.system_config_files + [myclirc] + [self.pwd_config_file]) c = self.config = read_config_files(config_files) self.multi_line = c['main'].as_bool('multi_line') @@ -150,7 +159,7 @@ def __init__(self, sqlexecute=None, prompt=None, # Write user config if system config wasn't the last config loaded. if c.filename not in self.system_config_files and not os.path.exists(myclirc): - write_default_config(self.default_config_file, myclirc) + write_default_config(myclirc) # audit log if self.logfile is None and 'audit_log' in c['main']: @@ -324,20 +333,33 @@ def read_my_cnf_files(self, files, keys): cnf = read_config_files(files, list_values=False) sections = ['client', 'mysqld'] + key_transformations = { + 'mysqld': { + 'socket': 'default_socket', + 'port': 'default_port', + }, + } + if self.login_path and self.login_path != 'client': sections.append(self.login_path) if self.defaults_suffix: sections.extend([sect + self.defaults_suffix for sect in sections]) - def get(key): - result = None - for sect in cnf: - if sect in sections and key in cnf[sect]: - result = strip_matching_quotes(cnf[sect][key]) - return result + configuration = defaultdict(lambda: None) + for key in keys: + for section in cnf: + if ( + section not in sections or + key not in cnf[section] + ): + continue + new_key = key_transformations.get(section, {}).get(key) or key + configuration[new_key] = strip_matching_quotes( + cnf[section][key]) + + return configuration - return {x: get(x) for x in keys} def merge_ssl_with_cnf(self, ssl, cnf): """Merge SSL configuration dict with cnf dict""" @@ -363,7 +385,8 @@ def merge_ssl_with_cnf(self, ssl, cnf): return merged def connect(self, database='', user='', passwd='', host='', port='', - socket='', charset='', local_infile='', ssl=None, init_command=''): + socket='', charset='', local_infile='', ssl=None, init_command='', + password_file=''): cnf = {'database': None, 'user': None, @@ -371,6 +394,7 @@ def connect(self, database='', user='', passwd='', host='', port='', 'host': None, 'port': None, 'socket': None, + 'default_socket': None, 'default-character-set': None, 'local-infile': None, 'loose-local-infile': None, @@ -384,18 +408,23 @@ def connect(self, database='', user='', passwd='', host='', port='', cnf = self.read_my_cnf_files(self.cnf_files, cnf.keys()) # Fall back to config values only if user did not specify a value. - database = database or cnf['database'] - # Socket interface not supported for SSH connections - if port or (host and host != 'localhost') or self.ssh_client: - socket = '' - else: - socket = socket or cnf['socket'] or guess_socket_location() user = user or cnf['user'] or os.getenv('USER') host = host or cnf['host'] - port = int(port or cnf['port'] or 3306) + port = port or cnf['port'] ssl = ssl or {} + port = port and int(port) + if not port and not self.ssh_client: + port = 3306 + if not host or host == 'localhost': + socket = ( + cnf['socket'] or + cnf['default_socket'] or + guess_socket_location() + ) + + passwd = passwd if isinstance(passwd, str) else cnf['password'] charset = charset or cnf['default-character-set'] or 'utf8' @@ -413,6 +442,10 @@ def connect(self, database='', user='', passwd='', host='', port='', if not any(v for v in ssl.values()): ssl = None + # if the passwd is not specfied try to set it using the password_file option + password_from_file = self.get_password_from_file(password_file) + passwd = passwd or password_from_file + # Connect to the database. def _connect(): @@ -422,9 +455,12 @@ def _connect(): local_infile, ssl, init_command, ssh_client=self.ssh_client ) except OperationalError as e: - if ('Access denied for user' in e.args[1]): - new_passwd = click.prompt('Password', hide_input=True, - show_default=False, type=str, err=True) + if e.args[0] == ERROR_CODE_ACCESS_DENIED: + if password_from_file: + new_passwd = password_from_file + else: + new_passwd = click.prompt('Password', hide_input=True, + show_default=False, type=str, err=True) self.sqlexecute = SQLExecute( database, user, new_passwd, host, port, socket, charset, local_infile, ssl, init_command, @@ -479,6 +515,17 @@ def _connect(): self.echo(str(e), err=True, fg='red') exit(1) + def get_password_from_file(self, password_file): + password_from_file = None + if password_file: + if (os.path.isfile(password_file) or stat.S_ISFIFO(os.stat(password_file).st_mode)) \ + and os.access(password_file, os.R_OK): + with open(password_file) as fp: + password_from_file = fp.readline() + password_from_file = password_from_file.rstrip().lstrip() + + return password_from_file + def handle_editor_command(self, text): r"""Editor command is any query that is prefixed or suffixed by a '\e'. The reason for a while loop is because a user might edit a query @@ -537,9 +584,6 @@ def run_cli(self): if self.smart_completion: self.refresh_completions() - author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') - sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') - history_file = os.path.expanduser( os.environ.get('MYCLI_HISTFILE', '~/.mycli-history')) if dir_path_exists(history_file): @@ -554,12 +598,10 @@ def run_cli(self): key_bindings = mycli_bindings(self) if not self.less_chatty: - print(' '.join(sqlexecute.server_type())) + print(sqlexecute.server_info) print('mycli', __version__) - print('Chat: https://gitter.im/dbcli/mycli') - print('Mail: https://groups.google.com/forum/#!forum/mycli-users') - print('Home: http://mycli.net') - print('Thanks to the contributor -', thanks_picker([author_file, sponsor_file])) + print(SUPPORT_INFO) + print('Thanks to the contributor -', thanks_picker()) def get_message(): prompt = self.get_prompt(self.prompt_format) @@ -857,8 +899,8 @@ def output(self, output, status=None): if not output_via_pager: # doesn't fit, flush buffer - for line in buf: - click.secho(line) + for buf_line in buf: + click.secho(buf_line) buf = [] else: click.secho(line) @@ -928,7 +970,7 @@ def get_prompt(self, string): string = string.replace('\\u', sqlexecute.user or '(none)') string = string.replace('\\h', host or '(none)') string = string.replace('\\d', sqlexecute.dbname or '(none)') - string = string.replace('\\t', sqlexecute.server_type()[0] or 'mycli') + string = string.replace('\\t', sqlexecute.server_info.species.name) string = string.replace('\\n', "\n") string = string.replace('\\D', now.strftime('%a %b %d %H:%M:%S %Y')) string = string.replace('\\m', now.strftime('%M')) @@ -1086,6 +1128,8 @@ def get_last_query(self): help='SQL statement to execute after connecting.') @click.option('--charset', type=str, help='Character set for MySQL session.') +@click.option('--password-file', type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.') @click.argument('database', default='', nargs=1) def cli(database, user, host, port, socket, password, dbname, version, verbose, prompt, logfile, defaults_group_suffix, @@ -1094,7 +1138,7 @@ def cli(database, user, host, port, socket, password, dbname, ssl_verify_server_cert, table, csv, warn, execute, myclirc, dsn, list_dsn, ssh_user, ssh_host, ssh_port, ssh_password, ssh_key_filename, list_ssh_config, ssh_config_path, ssh_config_host, - init_command, charset): + init_command, charset, password_file): """A MySQL terminal client with auto-completion and syntax highlighting. \b @@ -1230,6 +1274,7 @@ def cli(database, user, host, port, socket, password, dbname, ssl=ssl, init_command=init_command, charset=charset, + password_file=password_file, ) mycli.logger.debug('Launch Params: \n' @@ -1332,9 +1377,15 @@ def is_select(status): return status.split(None, 1)[0].lower() == 'select' -def thanks_picker(files=()): +def thanks_picker(): + import mycli + lines = ( + resources.read_text(mycli, 'AUTHORS') + + resources.read_text(mycli, 'SPONSORS') + ).split('\n') + contents = [] - for line in fileinput.input(files=files): + for line in lines: m = re.match(r'^ *\* (.*)', line) if m: contents.append(m.group(1)) diff --git a/mycli/packages/completion_engine.py b/mycli/packages/completion_engine.py index 3cff2ccc..c7db06cb 100644 --- a/mycli/packages/completion_engine.py +++ b/mycli/packages/completion_engine.py @@ -1,5 +1,3 @@ -import os -import sys import sqlparse from sqlparse.sql import Comparison, Identifier, Where from .parseutils import last_word, extract_tables, find_prev_keyword diff --git a/mycli/packages/parseutils.py b/mycli/packages/parseutils.py index 268e04e4..fa5f2c9e 100644 --- a/mycli/packages/parseutils.py +++ b/mycli/packages/parseutils.py @@ -12,7 +12,8 @@ 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), # This matches everything except a space. 'all_punctuations': re.compile(r'([^\s]+)$'), - } +} + def last_word(text, include='alphanum_underscore'): r""" @@ -226,14 +227,6 @@ def is_destructive(queries): return False -def is_open_quote(sql): - """Returns true if the query contains an unclosed quote.""" - - # parsed can contain one or more semi-colon separated commands - parsed = sqlparse.parse(sql) - return any(_parsed_is_open_quote(p) for p in parsed) - - if __name__ == '__main__': sql = 'select * from (select t. from tabl t' print (extract_tables(sql)) @@ -263,5 +256,4 @@ def normalize_db_name(db): ) if database_token is not None and normalize_db_name(database_token.get_name()) == dbname: result = keywords[0].normalized == "DROP" - else: - return result + return result diff --git a/mycli/packages/special/iocommands.py b/mycli/packages/special/iocommands.py index 58066b82..01f3c7ba 100644 --- a/mycli/packages/special/iocommands.py +++ b/mycli/packages/special/iocommands.py @@ -302,7 +302,7 @@ def execute_system_command(arg, **_): usage = "Syntax: system [command].\n" if not arg: - return [(None, None, None, usage)] + return [(None, None, None, usage)] try: command = arg.strip() diff --git a/mycli/packages/tabular_output/sql_format.py b/mycli/packages/tabular_output/sql_format.py index 730e6332..e6587bd3 100644 --- a/mycli/packages/tabular_output/sql_format.py +++ b/mycli/packages/tabular_output/sql_format.py @@ -1,6 +1,5 @@ """Format adapter for sql.""" -from cli_helpers.utils import filter_dict_by_key from mycli.packages.parseutils import extract_tables supported_formats = ('sql-insert', 'sql-update', 'sql-update-1', diff --git a/mycli/sqlcompleter.py b/mycli/sqlcompleter.py index 73b9b449..3656aa69 100644 --- a/mycli/sqlcompleter.py +++ b/mycli/sqlcompleter.py @@ -72,7 +72,7 @@ def escape_name(self, name): if name and ((not self.name_pattern.match(name)) or (name.upper() in self.reserved_words) or (name.upper() in self.functions)): - name = '`%s`' % name + name = '`%s`' % name return name diff --git a/mycli/sqlexecute.py b/mycli/sqlexecute.py index 5529974e..36592cac 100644 --- a/mycli/sqlexecute.py +++ b/mycli/sqlexecute.py @@ -1,6 +1,8 @@ +import enum import logging +import re + import pymysql -import sqlparse from .packages import special from pymysql.constants import FIELD_TYPE from pymysql.converters import (convert_datetime, @@ -16,17 +18,70 @@ }) +ERROR_CODE_ACCESS_DENIED = 1045 + + +class ServerSpecies(enum.Enum): + MySQL = 'MySQL' + MariaDB = 'MariaDB' + Percona = 'Percona' + Unknown = 'MySQL' + + +class ServerInfo: + def __init__(self, species, version_str): + self.species = species + self.version_str = version_str + self.version = self.calc_mysql_version_value(version_str) + + @staticmethod + def calc_mysql_version_value(version_str) -> int: + if not version_str or not isinstance(version_str, str): + return 0 + try: + major, minor, patch = version_str.split('.') + except ValueError: + return 0 + else: + return int(major) * 10_000 + int(minor) * 100 + int(patch) + + @classmethod + def from_version_string(cls, version_string): + if not version_string: + return cls(ServerSpecies.Unknown, '') + + re_species = ( + (r'(?P[0-9\.]+)-MariaDB', ServerSpecies.MariaDB), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[0-9]+$)', + ServerSpecies.Percona), + (r'(?P[0-9\.]+)[a-z0-9]*-(?P[A-Za-z0-9_]+)', + ServerSpecies.MySQL), + ) + for regexp, species in re_species: + match = re.search(regexp, version_string) + if match is not None: + parsed_version = match.group('version') + detected_species = species + break + else: + detected_species = ServerSpecies.Unknown + parsed_version = '' + + return cls(detected_species, parsed_version) + + def __str__(self): + if self.species: + return f'{self.species.value} {self.version_str}' + else: + return self.version_str + + class SQLExecute(object): databases_query = '''SHOW DATABASES''' tables_query = '''SHOW TABLES''' - version_query = '''SELECT @@VERSION''' - - version_comment_query = '''SELECT @@VERSION_COMMENT''' - version_comment_query_mysql4 = '''SHOW VARIABLES LIKE "version_comment"''' - show_candidates_query = '''SELECT name from mysql.help_topic WHERE name like "SHOW %"''' users_query = '''SELECT CONCAT("'", user, "'@'",host,"'") FROM mysql.user''' @@ -50,7 +105,7 @@ def __init__(self, database, user, password, host, port, socket, charset, self.charset = charset self.local_infile = local_infile self.ssl = ssl - self._server_type = None + self.server_info = None self.connection_id = None self.init_command = init_command self.ssh_client = ssh_client @@ -132,6 +187,7 @@ def connect(self, database=None, user=None, password=None, host=None, self.init_command = init_command # retrieve connection id self.reset_connection_id() + self.server_info = ServerInfo.from_version_string(conn.server_version) def run(self, statement): """Execute the sql in the database and return the results. The results @@ -248,37 +304,6 @@ def users(self): for row in cur: yield row - def server_type(self): - if self._server_type: - return self._server_type - with self.conn.cursor() as cur: - _logger.debug('Version Query. sql: %r', self.version_query) - cur.execute(self.version_query) - version = cur.fetchone()[0] - if version[0] == '4': - _logger.debug('Version Comment. sql: %r', - self.version_comment_query_mysql4) - cur.execute(self.version_comment_query_mysql4) - version_comment = cur.fetchone()[1].lower() - if isinstance(version_comment, bytes): - # with python3 this query returns bytes - version_comment = version_comment.decode('utf-8') - else: - _logger.debug('Version Comment. sql: %r', - self.version_comment_query) - cur.execute(self.version_comment_query) - version_comment = cur.fetchone()[0].lower() - - if 'mariadb' in version_comment: - product_type = 'mariadb' - elif 'percona' in version_comment: - product_type = 'percona' - else: - product_type = 'mysql' - - self._server_type = (product_type, version) - return self._server_type - def get_connection_id(self): if not self.connection_id: self.reset_connection_id() diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..5422131c --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = --ignore=mycli/packages/paramiko_stub/__init__.py diff --git a/release.py b/release.py index 30c41b3f..3f18f03f 100755 --- a/release.py +++ b/release.py @@ -1,6 +1,5 @@ """A script to publish a release of mycli to PyPI.""" -import io from optparse import OptionParser import re import subprocess diff --git a/setup.py b/setup.py index 4aa7f91a..5acbae7c 100755 --- a/setup.py +++ b/setup.py @@ -23,11 +23,14 @@ 'PyMySQL >= 0.9.2', 'sqlparse>=0.3.0,<0.4.0', 'configobj >= 5.0.5', - 'cryptography >= 1.0.0', 'cli_helpers[styles] >= 2.0.1', - 'pyperclip >= 1.8.1' + 'pyperclip >= 1.8.1', + 'pyaes >= 1.6.1' ] +if sys.version_info.minor < 9: + install_requirements.append('importlib_resources >= 5.0.0') + class lint(Command): description = 'check code against PEP 8 (and fix violations)' diff --git a/test/features/connection.feature b/test/features/connection.feature new file mode 100644 index 00000000..b06935ea --- /dev/null +++ b/test/features/connection.feature @@ -0,0 +1,35 @@ +Feature: connect to a database: + + @requires_local_db + Scenario: run mycli on localhost without port + When we run mycli with arguments "host=localhost" without arguments "port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli on TCP host without port + When we run mycli without arguments "port" + When we query "status" + Then status contains "via TCP/IP" + + Scenario: run mycli with port but without host + When we run mycli without arguments "host" + When we query "status" + Then status contains "via TCP/IP" + + @requires_local_db + Scenario: run mycli without host and port + When we run mycli without arguments "host port" + When we query "status" + Then status contains "via UNIX socket" + + Scenario: run mycli with my.cnf configuration + When we create my.cnf file + When we run mycli without arguments "host port user pass defaults_file" + Then we are logged in + + Scenario: run mycli with mylogin.cnf configuration + When we create mylogin.cnf file + When we run mycli with arguments "login_path=test_login_path" without arguments "host port user pass defaults_file" + Then we are logged in + + diff --git a/test/features/environment.py b/test/features/environment.py index 98c20049..1ea0f086 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -1,4 +1,5 @@ import os +import shutil import sys from tempfile import mkstemp @@ -11,6 +12,24 @@ test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +SELF_CONNECTING_FEATURES = ( + 'test/features/connection.feature', +) + + +MY_CNF_PATH = os.path.expanduser('~/.my.cnf') +MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' +MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') +MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' + + +def get_db_name_from_context(context): + return context.config.userdata.get( + 'my_test_db', None + ) or "mycli_behave_tests" + + + def before_all(context): """Set env parameters.""" os.environ['LINES'] = "100" @@ -22,7 +41,7 @@ def before_all(context): test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) login_path_file = os.path.join(test_dir, 'mylogin.cnf') - os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file +# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file context.package_root = os.path.abspath( os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) @@ -33,8 +52,7 @@ def before_all(context): context.exit_sent = False vi = '_'.join([str(x) for x in sys.version_info[:3]]) - db_name = context.config.userdata.get( - 'my_test_db', None) or "mycli_behave_tests" + db_name = get_db_name_from_context(context) db_name_full = '{0}_{1}'.format(db_name, vi) # Store get params from config/environment variables @@ -104,11 +122,18 @@ def before_step(context, _): context.atprompt = False -def before_scenario(context, _): +def before_scenario(context, arg): with open(test_log_file, 'w') as f: f.write('') - run_cli(context) - wait_prompt(context) + if arg.location.filename not in SELF_CONNECTING_FEATURES: + run_cli(context) + wait_prompt(context) + + if os.path.exists(MY_CNF_PATH): + shutil.move(MY_CNF_PATH, MY_CNF_BACKUP_PATH) + + if os.path.exists(MYLOGIN_CNF_PATH): + shutil.move(MYLOGIN_CNF_PATH, MYLOGIN_CNF_BACKUP_PATH) def after_scenario(context, _): @@ -134,6 +159,17 @@ def after_scenario(context, _): context.cli.sendcontrol('d') context.cli.expect_exact(pexpect.EOF, timeout=5) + if os.path.exists(MY_CNF_BACKUP_PATH): + shutil.move(MY_CNF_BACKUP_PATH, MY_CNF_PATH) + + if os.path.exists(MYLOGIN_CNF_BACKUP_PATH): + shutil.move(MYLOGIN_CNF_BACKUP_PATH, MYLOGIN_CNF_PATH) + elif os.path.exists(MYLOGIN_CNF_PATH): + # This file was moved in `before_scenario`. + # If it exists now, it has been created during a test + os.remove(MYLOGIN_CNF_PATH) + + # TODO: uncomment to debug a failure # def after_step(context, step): # if step.status == "failed": diff --git a/test/features/steps/auto_vertical.py b/test/features/steps/auto_vertical.py index 974740d7..e1cb26f8 100644 --- a/test/features/steps/auto_vertical.py +++ b/test/features/steps/auto_vertical.py @@ -3,11 +3,12 @@ from behave import then, when import wrappers +from utils import parse_cli_args_to_dict @when('we run dbcli with {arg}') def step_run_cli_with_arg(context, arg): - wrappers.run_cli(context, run_args=arg.split('=')) + wrappers.run_cli(context, run_args=parse_cli_args_to_dict(arg)) @when('we execute a small query') diff --git a/test/features/steps/connection.py b/test/features/steps/connection.py new file mode 100644 index 00000000..e16dd867 --- /dev/null +++ b/test/features/steps/connection.py @@ -0,0 +1,71 @@ +import io +import os +import shlex + +from behave import when, then +import pexpect + +import wrappers +from test.features.steps.utils import parse_cli_args_to_dict +from test.features.environment import MY_CNF_PATH, MYLOGIN_CNF_PATH, get_db_name_from_context +from test.utils import HOST, PORT, USER, PASSWORD +from mycli.config import encrypt_mylogin_cnf + + +TEST_LOGIN_PATH = 'test_login_path' + + +@when('we run mycli with arguments "{exact_args}" without arguments "{excluded_args}"') +@when('we run mycli without arguments "{excluded_args}"') +def step_run_cli_without_args(context, excluded_args, exact_args=''): + wrappers.run_cli( + context, + run_args=parse_cli_args_to_dict(exact_args), + exclude_args=parse_cli_args_to_dict(excluded_args).keys() + ) + + +@then('status contains "{expression}"') +def status_contains(context, expression): + wrappers.expect_exact(context, f'{expression}', timeout=5) + + # Normally, the shutdown after scenario waits for the prompt. + # But we may have changed the prompt, depending on parameters, + # so let's wait for its last character + context.cli.expect_exact('>') + context.atprompt = True + + +@when('we create my.cnf file') +def step_create_my_cnf_file(context): + my_cnf = ( + '[client]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MY_CNF_PATH, 'w') as f: + f.write(my_cnf) + + +@when('we create mylogin.cnf file') +def step_create_mylogin_cnf_file(context): + os.environ.pop('MYSQL_TEST_LOGIN_FILE', None) + mylogin_cnf = ( + f'[{TEST_LOGIN_PATH}]\n' + f'host = {HOST}\n' + f'port = {PORT}\n' + f'user = {USER}\n' + f'password = {PASSWORD}\n' + ) + with open(MYLOGIN_CNF_PATH, 'wb') as f: + input_file = io.StringIO(mylogin_cnf) + f.write(encrypt_mylogin_cnf(input_file).read()) + + +@then('we are logged in') +def we_are_logged_in(context): + db_name = get_db_name_from_context(context) + context.cli.expect_exact(f'{db_name}>', timeout=5) + context.atprompt = True diff --git a/test/features/steps/utils.py b/test/features/steps/utils.py new file mode 100644 index 00000000..1ae63d2b --- /dev/null +++ b/test/features/steps/utils.py @@ -0,0 +1,12 @@ +import shlex + + +def parse_cli_args_to_dict(cli_args: str): + args_dict = {} + for arg in shlex.split(cli_args): + if '=' in arg: + key, value = arg.split('=') + args_dict[key] = value + else: + args_dict[arg] = None + return args_dict diff --git a/test/features/steps/wrappers.py b/test/features/steps/wrappers.py index de833dd2..6408f235 100644 --- a/test/features/steps/wrappers.py +++ b/test/features/steps/wrappers.py @@ -3,6 +3,7 @@ import sys import textwrap + try: from StringIO import StringIO except ImportError: @@ -13,7 +14,7 @@ def expect_exact(context, expected, timeout): timedout = False try: context.cli.expect_exact(expected, timeout=timeout) - except pexpect.exceptions.TIMEOUT: + except pexpect.TIMEOUT: timedout = True if timedout: # Strip color codes out of the output. @@ -46,21 +47,43 @@ def expect_pager(context, expected, timeout): context.conf['pager_boundary'], expected), timeout=timeout) -def run_cli(context, run_args=None): +def run_cli(context, run_args=None, exclude_args=None): """Run the process using pexpect.""" - run_args = run_args or [] - if context.conf.get('host', None): - run_args.extend(('-h', context.conf['host'])) - if context.conf.get('user', None): - run_args.extend(('-u', context.conf['user'])) - if context.conf.get('pass', None): - run_args.extend(('-p', context.conf['pass'])) - if context.conf.get('dbname', None): - run_args.extend(('-D', context.conf['dbname'])) - if context.conf.get('defaults-file', None): - run_args.extend(('--defaults-file', context.conf['defaults-file'])) - if context.conf.get('myclirc', None): - run_args.extend(('--myclirc', context.conf['myclirc'])) + run_args = run_args or {} + rendered_args = [] + exclude_args = set(exclude_args) if exclude_args else set() + + conf = dict(**context.conf) + conf.update(run_args) + + def add_arg(name, key, value): + if name not in exclude_args: + if value is not None: + rendered_args.extend((key, value)) + else: + rendered_args.append(key) + + if conf.get('host', None): + add_arg('host', '-h', conf['host']) + if conf.get('user', None): + add_arg('user', '-u', conf['user']) + if conf.get('pass', None): + add_arg('pass', '-p', conf['pass']) + if conf.get('port', None): + add_arg('port', '-P', str(conf['port'])) + if conf.get('dbname', None): + add_arg('dbname', '-D', conf['dbname']) + if conf.get('defaults-file', None): + add_arg('defaults_file', '--defaults-file', conf['defaults-file']) + if conf.get('myclirc', None): + add_arg('myclirc', '--myclirc', conf['myclirc']) + if conf.get('login_path'): + add_arg('login_path', '--login-path', conf['login_path']) + + for arg_name, arg_value in conf.items(): + if arg_name.startswith('-'): + add_arg(arg_name, arg_name, arg_value) + try: cli_cmd = context.conf['cli_command'] except KeyError: @@ -73,7 +96,7 @@ def run_cli(context, run_args=None): '"' ).format(sys.executable) - cmd_parts = [cli_cmd] + run_args + cmd_parts = [cli_cmd] + rendered_args cmd = ' '.join(cmd_parts) context.cli = pexpect.spawnu(cmd, cwd=context.package_root) context.logfile = StringIO() diff --git a/test/test_main.py b/test/test_main.py index b63f86a0..07d8f25a 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -3,8 +3,9 @@ import click from click.testing import CliRunner -from mycli.main import MyCli, cli, thanks_picker, PACKAGE_ROOT +from mycli.main import MyCli, cli, thanks_picker from mycli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from mycli.sqlexecute import ServerInfo from .utils import USER, HOST, PORT, PASSWORD, dbtest, run from textwrap import dedent @@ -140,10 +141,7 @@ def test_batch_mode_csv(executor): def test_thanks_picker_utf8(): - author_file = os.path.join(PACKAGE_ROOT, 'AUTHORS') - sponsor_file = os.path.join(PACKAGE_ROOT, 'SPONSORS') - - name = thanks_picker((author_file, sponsor_file)) + name = thanks_picker() assert name and isinstance(name, str) @@ -177,6 +175,7 @@ class TestExecute(): host = 'test' user = 'test' dbname = 'test' + server_info = ServerInfo.from_version_string('unknown') port = 0 def server_type(self): diff --git a/test/test_sqlexecute.py b/test/test_sqlexecute.py index 5168bf6f..0f38a97e 100644 --- a/test/test_sqlexecute.py +++ b/test/test_sqlexecute.py @@ -3,6 +3,7 @@ import pytest import pymysql +from mycli.sqlexecute import ServerInfo, ServerSpecies from .utils import run, dbtest, set_expanded_output, is_expanded_output @@ -270,3 +271,24 @@ def test_multiple_results(executor): 'status': '1 row in set'} ] assert results == expected + + +@pytest.mark.parametrize( + 'version_string, species, parsed_version_string, version', + ( + ('5.7.32-35', 'Percona', '5.7.32', 50732), + ('5.7.32-0ubuntu0.18.04.1', 'MySQL', '5.7.32', 50732), + ('10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.5.5-10.5.8-MariaDB-1:10.5.8+maria~focal', 'MariaDB', '10.5.8', 100508), + ('5.0.16-pro-nt-log', 'MySQL', '5.0.16', 50016), + ('5.1.5a-alpha', 'MySQL', '5.1.5', 50105), + ('unexpected version string', None, '', 0), + ('', None, '', 0), + (None, None, '', 0), + ) +) +def test_version_parsing(version_string, species, parsed_version_string, version): + server_info = ServerInfo.from_version_string(version_string) + assert (server_info.species and server_info.species.name) == species or ServerSpecies.Unknown + assert server_info.version_str == parsed_version_string + assert server_info.version == version