From 475c6936d6b67e89fb4ec9ae8bd9c3609ef7437b Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 30 May 2026 11:26:59 -0400 Subject: [PATCH] migrate CLI argument handling back to main.py This is a bit of churn. We spun absolutely everything out of main.py in a series of steps, but it isn't required that the file have zero responsibilities. Here mycli/cli_args.py is removed, and its responsibilities brought back to main.py. Some values are also moved to constants.py to avoid circular imports. --- changelog.md | 1 + mycli/cli_args.py | 373 --------------------------- mycli/cli_runner.py | 12 +- mycli/client.py | 2 +- mycli/client_connection.py | 9 +- mycli/constants.py | 3 + mycli/main.py | 371 +++++++++++++++++++++++++- mycli/main_modes/batch.py | 2 +- mycli/main_modes/execute.py | 2 +- mycli/main_modes/list_ssh_config.py | 2 +- test/pytests/test_cli_args.py | 175 ------------- test/pytests/test_main.py | 173 ++++++++++++- test/pytests/test_main_regression.py | 2 +- 13 files changed, 562 insertions(+), 565 deletions(-) delete mode 100644 mycli/cli_args.py delete mode 100644 test/pytests/test_cli_args.py diff --git a/changelog.md b/changelog.md index 2a5f2d6c..07337c33 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ Upcoming (TBD) Internal --------- * Factor `main.py` into several files using mixins. +* Move CLI argument handling back to `main.py`. * Update Python versions used in CI. * Add CI on macOS. diff --git a/mycli/cli_args.py b/mycli/cli_args.py deleted file mode 100644 index bf95f59d..00000000 --- a/mycli/cli_args.py +++ /dev/null @@ -1,373 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from io import TextIOWrapper -import os -import sys -from typing import Callable - -import click -import clickdc - -EMPTY_PASSWORD_FLAG_SENTINEL = -1 -DEFAULT_PROMPT = "\\t \\u@\\h:\\d> " - - -class IntOrStringClickParamType(click.ParamType): - name = 'text' # display as TEXT in helpdoc - - def convert(self, value, param, ctx): - if isinstance(value, int): - return value - elif isinstance(value, str): - return value - elif value is None: - return value - else: - self.fail('Not a valid password string', param, ctx) - - -INT_OR_STRING_CLICK_TYPE = IntOrStringClickParamType() - - -@dataclass(slots=True) -class CliArgs: - database: str | None = clickdc.argument( - type=str, - default=None, - nargs=1, - ) - host: str | None = clickdc.option( - '-h', - '--hostname', - 'host', - type=str, - envvar='MYSQL_HOST', - help='Host address of the database.', - ) - port: int | None = clickdc.option( - '-P', - type=int, - envvar='MYSQL_TCP_PORT', - help='Port number to use for connection. Honors $MYSQL_TCP_PORT.', - ) - user: str | None = clickdc.option( - '-u', - '--user', - '--username', - 'user', - type=str, - envvar='MYSQL_USER', - help='User name to connect to the database.', - ) - socket: str | None = clickdc.option( - '-S', - type=str, - envvar='MYSQL_UNIX_SOCKET', - help='The socket file to use for connection.', - ) - password: int | str | None = clickdc.option( - '-p', - '--pass', - '--password', - 'password', - type=INT_OR_STRING_CLICK_TYPE, - is_flag=False, - flag_value=EMPTY_PASSWORD_FLAG_SENTINEL, - help='Prompt for (or pass in cleartext) the password to connect to the database.', - ) - password_file: str | None = clickdc.option( - type=click.Path(), - help='File or FIFO path containing the password to connect to the db if not specified otherwise.', - ) - ssh_user: str | None = clickdc.option( - type=str, - help='User name to connect to ssh server.', - ) - ssh_host: str | None = clickdc.option( - type=str, - help='Host name to connect to ssh server.', - ) - ssh_port: int = clickdc.option( - type=int, - default=22, - help='Port to connect to ssh server.', - ) - ssh_password: str | None = clickdc.option( - type=str, - help='Password to connect to ssh server.', - ) - ssh_key_filename: str | None = clickdc.option( - type=str, - help='Private key filename (identify file) for the ssh connection.', - ) - ssh_config_path: str = clickdc.option( - type=str, - help='Path to ssh configuration.', - default=os.path.expanduser('~') + '/.ssh/config', - ) - ssh_config_host: str | None = clickdc.option( - type=str, - help='Host to connect to ssh server reading from ssh configuration.', - ) - list_ssh_config: bool = clickdc.option( - is_flag=True, - help='list ssh configurations in the ssh config (requires paramiko).', - ) - ssh_warning_off: bool = clickdc.option( - is_flag=True, - help='Suppress the SSH deprecation notice.', - ) - ssl_mode: str = clickdc.option( - type=click.Choice(['auto', 'on', 'off']), - help='Set desired SSL behavior. auto=preferred if TCP/IP, on=required, off=off.', - ) - deprecated_ssl: bool | None = clickdc.option( - '--ssl/--no-ssl', - 'deprecated_ssl', - default=None, - clickdc=None, - help='Enable SSL for connection (automatically enabled with other flags).', - ) - ssl_ca: str | None = clickdc.option( - type=click.Path(exists=True), - help='CA file in PEM format.', - ) - ssl_capath: str | None = clickdc.option( - type=click.Path(exists=True, file_okay=False, dir_okay=True), - help='CA directory.', - ) - ssl_cert: str | None = clickdc.option( - type=click.Path(exists=True), - help='X509 cert in PEM format.', - ) - ssl_key: str | None = clickdc.option( - type=click.Path(exists=True), - help='X509 key in PEM format.', - ) - ssl_cipher: str | None = clickdc.option( - type=str, - help='SSL cipher to use.', - ) - tls_version: str | None = clickdc.option( - type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), - help='TLS protocol version for secure connection.', - ) - ssl_verify_server_cert: bool = clickdc.option( - is_flag=True, - help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), - ) - verbose: int = clickdc.option( - '-v', - count=True, - help='More verbose output and feedback. Can be given multiple times.', - ) - quiet: bool = clickdc.option( - '-q', - is_flag=True, - help='Less verbose output and feedback.', - ) - dbname: str | None = clickdc.option( - '-D', - '--database', - 'dbname', - type=str, - clickdc=None, - help='Database or DSN to use for the connection.', - ) - dsn: str = clickdc.option( - '-d', - type=str, - default='', - envvar='MYSQL_DSN', - help='DSN alias configured in the ~/.myclirc file, or a full DSN.', - ) - list_dsn: bool = clickdc.option( - is_flag=True, - help='Show list of DSN aliases configured in the [alias_dsn] section of ~/.myclirc.', - ) - prompt: str | None = clickdc.option( - '-R', - type=str, - help=f'Prompt format (Default: "{DEFAULT_PROMPT}").', - ) - toolbar: str | None = clickdc.option( - type=str, - help='Toolbar format.', - ) - logfile: TextIOWrapper | None = clickdc.option( - '-l', - type=click.File(mode='a', encoding='utf-8'), - help='Log every query and its results to a file.', - ) - checkpoint: TextIOWrapper | None = clickdc.option( - type=click.File(mode='a', encoding='utf-8'), - help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', - ) - resume: bool = clickdc.option( - '--resume', - is_flag=True, - help='In batch mode, resume after replaying statements in the --checkpoint file.', - ) - defaults_group_suffix: str | None = clickdc.option( - type=str, - help='Read MySQL config groups with the specified suffix.', - ) - defaults_file: str | None = clickdc.option( - type=click.Path(), - help='Only read MySQL options from the given file.', - ) - myclirc: str = clickdc.option( - type=click.Path(), - default='~/.myclirc', - help='Location of myclirc file.', - ) - auto_vertical_output: bool = clickdc.option( - is_flag=True, - help='Automatically switch to vertical output mode if the result is wider than the terminal width.', - ) - show_warnings: bool | None = clickdc.option( - '--show-warnings/--no-show-warnings', - is_flag=True, - default=None, - clickdc=None, - help='Automatically show warnings after executing a SQL statement.', - ) - table: bool = clickdc.option( - '-t', - is_flag=True, - help='Shorthand for --format=table.', - ) - csv: bool = clickdc.option( - is_flag=True, - help='Shorthand for --format=csv.', - ) - warn: bool | None = clickdc.option( - '--warn/--no-warn', - default=None, - clickdc=None, - help='Warn before running a destructive query.', - ) - local_infile: bool | None = clickdc.option( - type=bool, - is_flag=False, - default=None, - help='Enable/disable LOAD DATA LOCAL INFILE.', - ) - login_path: str | None = clickdc.option( - '-g', - type=str, - help='Read this path from the login file.', - ) - execute: str | None = clickdc.option( - '-e', - type=str, - help='Execute command and quit.', - ) - init_command: str | None = clickdc.option( - type=str, - help='SQL statement to execute after connecting.', - ) - unbuffered: bool | None = clickdc.option( - is_flag=True, - help='Instead of copying every row of data into a buffer, fetch rows as needed, to save memory.', - ) - character_set: str | None = clickdc.option( - '--charset', - '--character-set', - 'character_set', - type=str, - help='Character set for MySQL session.', - ) - batch: str | None = clickdc.option( - type=str, - help='SQL script to execute in batch mode.', - ) - noninteractive: bool = clickdc.option( - is_flag=True, - help="Don't prompt during batch input. Recommended.", - ) - format: str | None = clickdc.option( - type=click.Choice(['default', 'csv', 'tsv', 'table']), - help='Format for batch or --execute output.', - ) - throttle: float = clickdc.option( - type=float, - default=0.0, - help='Pause in seconds between queries in batch mode.', - ) - progress: bool = clickdc.option( - is_flag=True, - help='Show progress on the standard error with --batch.', - ) - use_keyring: str | None = clickdc.option( - type=click.Choice(['true', 'false', 'reset']), - default=None, - help='Store and retrieve passwords from the system keyring: true/false/reset.', - ) - keepalive_ticks: int | None = clickdc.option( - type=int, - help='Send regular keepalive pings to the connection, roughly every seconds.', - ) - checkup: bool = clickdc.option( - is_flag=True, - help='Run a checkup on your configuration.', - ) - - -def get_password_from_file(password_file: str | None) -> str | None: - if not password_file: - return None - try: - with open(password_file) as fp: - return fp.readline().removesuffix('\n') - except FileNotFoundError: - click.secho(f"Password file '{password_file}' not found", err=True, fg='red') - sys.exit(1) - except PermissionError: - click.secho(f"Permission denied reading password file '{password_file}'", err=True, fg='red') - sys.exit(1) - except IsADirectoryError: - click.secho(f"Path '{password_file}' is a directory, not a file", err=True, fg='red') - sys.exit(1) - except Exception as e: - click.secho(f"Error reading password file '{password_file}': {str(e)}", err=True, fg='red') - sys.exit(1) - - -def preprocess_cli_args( - cli_args: CliArgs, - is_valid_connection_scheme: Callable[[str], tuple[bool, str | None]], -) -> int: - if cli_args.database is None and isinstance(cli_args.password, str) and '://' in cli_args.password: - is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.password) - if not is_valid_scheme: - click.secho(f'Error: Unknown connection scheme provided for DSN URI ({scheme}://)', err=True, fg='red') - sys.exit(1) - cli_args.database = cli_args.password - cli_args.password = EMPTY_PASSWORD_FLAG_SENTINEL - - if cli_args.password is None and cli_args.password_file: - password_from_file = get_password_from_file(cli_args.password_file) - if password_from_file is not None: - cli_args.password = password_from_file - - if cli_args.password is None and os.environ.get('MYSQL_PWD') is not None: - cli_args.password = os.environ.get('MYSQL_PWD') - - if cli_args.resume and not cli_args.checkpoint: - click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') - sys.exit(1) - - if cli_args.resume and not cli_args.batch: - click.secho('Error: --resume requires a --batch file.', err=True, fg='red') - sys.exit(1) - - if cli_args.verbose and cli_args.quiet: - click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') - sys.exit(1) - elif cli_args.verbose: - return int(cli_args.verbose) - elif cli_args.quiet: - return -1 - return 0 diff --git a/mycli/cli_runner.py b/mycli/cli_runner.py index d4da00e1..2954ec29 100644 --- a/mycli/cli_runner.py +++ b/mycli/cli_runner.py @@ -3,23 +3,25 @@ import os import sys from textwrap import dedent -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable from urllib.parse import parse_qs, unquote, urlparse import click -from mycli.cli_args import EMPTY_PASSWORD_FLAG_SENTINEL, CliArgs, preprocess_cli_args from mycli.config import str_to_bool -from mycli.constants import ISSUES_URL, REPO_URL +from mycli.constants import EMPTY_PASSWORD_FLAG_SENTINEL, ISSUES_URL, REPO_URL from mycli.packages.ssh_utils import read_ssh_config +if TYPE_CHECKING: + from mycli.main import CliArgs + ClientFactory = Callable[..., Any] -def run_from_cli_args(cli_args: CliArgs, client_factory: ClientFactory) -> None: +def run_from_cli_args(cli_args: 'CliArgs', client_factory: ClientFactory) -> None: from mycli import main as main_module - cli_verbosity = preprocess_cli_args(cli_args, main_module.is_valid_connection_scheme) + cli_verbosity = main_module.preprocess_cli_args(cli_args, main_module.is_valid_connection_scheme) mycli = client_factory( prompt=cli_args.prompt, diff --git a/mycli/client.py b/mycli/client.py index a7c04e14..c153660a 100644 --- a/mycli/client.py +++ b/mycli/client.py @@ -18,10 +18,10 @@ llm_prompt_truncation, normalize_ssl_mode, ) -from mycli.cli_args import DEFAULT_PROMPT from mycli.client_commands import ClientCommandsMixin from mycli.client_connection import ClientConnectionMixin from mycli.client_query import ClientQueryMixin +from mycli.constants import DEFAULT_PROMPT from mycli.main_modes import repl as repl_package from mycli.output import OutputMixin from mycli.packages import special diff --git a/mycli/client_connection.py b/mycli/client_connection.py index c4492474..79f4e582 100644 --- a/mycli/client_connection.py +++ b/mycli/client_connection.py @@ -10,9 +10,14 @@ from pymysql.constants.CR import CR_SERVER_LOST from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR -from mycli.cli_args import EMPTY_PASSWORD_FLAG_SENTINEL from mycli.compat import WIN -from mycli.constants import DEFAULT_CHARSET, DEFAULT_HOST, DEFAULT_PORT, ER_MUST_CHANGE_PASSWORD_LOGIN +from mycli.constants import ( + DEFAULT_CHARSET, + DEFAULT_HOST, + DEFAULT_PORT, + EMPTY_PASSWORD_FLAG_SENTINEL, + ER_MUST_CHANGE_PASSWORD_LOGIN, +) try: from pwd import getpwuid diff --git a/mycli/constants.py b/mycli/constants.py index f6ef1900..65482412 100644 --- a/mycli/constants.py +++ b/mycli/constants.py @@ -17,3 +17,6 @@ # MySQL error codes not available in pymysql.constants.ER ER_MUST_CHANGE_PASSWORD_LOGIN = 1862 ER_MUST_CHANGE_PASSWORD = 1820 + +EMPTY_PASSWORD_FLAG_SENTINEL = -1 +DEFAULT_PROMPT = "\\t \\u@\\h:\\d> " diff --git a/mycli/main.py b/mycli/main.py index f3066ecb..5f15636c 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -1,7 +1,10 @@ from __future__ import annotations +from dataclasses import dataclass +from io import TextIOWrapper import os import sys +from typing import Callable from cli_helpers.tabular_output import TabularOutputFormatter from cli_helpers.tabular_output.output_formatter import MISSING_VALUE as DEFAULT_MISSING_VALUE @@ -13,13 +16,16 @@ from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR import mycli as mycli_package -from mycli.cli_args import EMPTY_PASSWORD_FLAG_SENTINEL, CliArgs from mycli.cli_runner import run_from_cli_args from mycli.client import MyCli from mycli.clistyle import style_factory_helpers, style_factory_ptoolkit from mycli.completion_refresher import CompletionRefresher from mycli.config import get_mylogin_cnf_path, open_mylogin_cnf, read_config_files, str_to_bool, write_default_config -from mycli.constants import ER_MUST_CHANGE_PASSWORD_LOGIN +from mycli.constants import ( + DEFAULT_PROMPT, + EMPTY_PASSWORD_FLAG_SENTINEL, + ER_MUST_CHANGE_PASSWORD_LOGIN, +) from mycli.main_modes.batch import main_batch_from_stdin, main_batch_with_progress_bar, main_batch_without_progress_bar from mycli.main_modes.checkup import main_checkup from mycli.main_modes.execute import main_execute_from_cli @@ -51,7 +57,6 @@ 'SQLExecute', 'SchemaPrefetcher', 'TabularOutputFormatter', - 'CliArgs', 'CompletionRefresher', 'click_entrypoint', 'confirm_destructive_query', @@ -84,6 +89,366 @@ ] +class IntOrStringClickParamType(click.ParamType): + name = 'text' # display as TEXT in helpdoc + + def convert(self, value, param, ctx): + if isinstance(value, int): + return value + elif isinstance(value, str): + return value + elif value is None: + return value + else: + self.fail('Not a valid password string', param, ctx) + + +INT_OR_STRING_CLICK_TYPE = IntOrStringClickParamType() + + +@dataclass(slots=True) +class CliArgs: + database: str | None = clickdc.argument( + type=str, + default=None, + nargs=1, + ) + host: str | None = clickdc.option( + '-h', + '--hostname', + 'host', + type=str, + envvar='MYSQL_HOST', + help='Host address of the database.', + ) + port: int | None = clickdc.option( + '-P', + type=int, + envvar='MYSQL_TCP_PORT', + help='Port number to use for connection. Honors $MYSQL_TCP_PORT.', + ) + user: str | None = clickdc.option( + '-u', + '--user', + '--username', + 'user', + type=str, + envvar='MYSQL_USER', + help='User name to connect to the database.', + ) + socket: str | None = clickdc.option( + '-S', + type=str, + envvar='MYSQL_UNIX_SOCKET', + help='The socket file to use for connection.', + ) + password: int | str | None = clickdc.option( + '-p', + '--pass', + '--password', + 'password', + type=INT_OR_STRING_CLICK_TYPE, + is_flag=False, + flag_value=EMPTY_PASSWORD_FLAG_SENTINEL, + help='Prompt for (or pass in cleartext) the password to connect to the database.', + ) + password_file: str | None = clickdc.option( + type=click.Path(), + help='File or FIFO path containing the password to connect to the db if not specified otherwise.', + ) + ssh_user: str | None = clickdc.option( + type=str, + help='User name to connect to ssh server.', + ) + ssh_host: str | None = clickdc.option( + type=str, + help='Host name to connect to ssh server.', + ) + ssh_port: int = clickdc.option( + type=int, + default=22, + help='Port to connect to ssh server.', + ) + ssh_password: str | None = clickdc.option( + type=str, + help='Password to connect to ssh server.', + ) + ssh_key_filename: str | None = clickdc.option( + type=str, + help='Private key filename (identify file) for the ssh connection.', + ) + ssh_config_path: str = clickdc.option( + type=str, + help='Path to ssh configuration.', + default=os.path.expanduser('~') + '/.ssh/config', + ) + ssh_config_host: str | None = clickdc.option( + type=str, + help='Host to connect to ssh server reading from ssh configuration.', + ) + list_ssh_config: bool = clickdc.option( + is_flag=True, + help='list ssh configurations in the ssh config (requires paramiko).', + ) + ssh_warning_off: bool = clickdc.option( + is_flag=True, + help='Suppress the SSH deprecation notice.', + ) + ssl_mode: str = clickdc.option( + type=click.Choice(['auto', 'on', 'off']), + help='Set desired SSL behavior. auto=preferred if TCP/IP, on=required, off=off.', + ) + deprecated_ssl: bool | None = clickdc.option( + '--ssl/--no-ssl', + 'deprecated_ssl', + default=None, + clickdc=None, + help='Enable SSL for connection (automatically enabled with other flags).', + ) + ssl_ca: str | None = clickdc.option( + type=click.Path(exists=True), + help='CA file in PEM format.', + ) + ssl_capath: str | None = clickdc.option( + type=click.Path(exists=True, file_okay=False, dir_okay=True), + help='CA directory.', + ) + ssl_cert: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 cert in PEM format.', + ) + ssl_key: str | None = clickdc.option( + type=click.Path(exists=True), + help='X509 key in PEM format.', + ) + ssl_cipher: str | None = clickdc.option( + type=str, + help='SSL cipher to use.', + ) + tls_version: str | None = clickdc.option( + type=click.Choice(['TLSv1', 'TLSv1.1', 'TLSv1.2', 'TLSv1.3'], case_sensitive=False), + help='TLS protocol version for secure connection.', + ) + ssl_verify_server_cert: bool = clickdc.option( + is_flag=True, + help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""), + ) + verbose: int = clickdc.option( + '-v', + count=True, + help='More verbose output and feedback. Can be given multiple times.', + ) + quiet: bool = clickdc.option( + '-q', + is_flag=True, + help='Less verbose output and feedback.', + ) + dbname: str | None = clickdc.option( + '-D', + '--database', + 'dbname', + type=str, + clickdc=None, + help='Database or DSN to use for the connection.', + ) + dsn: str = clickdc.option( + '-d', + type=str, + default='', + envvar='MYSQL_DSN', + help='DSN alias configured in the ~/.myclirc file, or a full DSN.', + ) + list_dsn: bool = clickdc.option( + is_flag=True, + help='Show list of DSN aliases configured in the [alias_dsn] section of ~/.myclirc.', + ) + prompt: str | None = clickdc.option( + '-R', + type=str, + help=f'Prompt format (Default: "{DEFAULT_PROMPT}").', + ) + toolbar: str | None = clickdc.option( + type=str, + help='Toolbar format.', + ) + logfile: TextIOWrapper | None = clickdc.option( + '-l', + type=click.File(mode='a', encoding='utf-8'), + help='Log every query and its results to a file.', + ) + checkpoint: TextIOWrapper | None = clickdc.option( + type=click.File(mode='a', encoding='utf-8'), + help='In batch or --execute mode, log successful queries to a file, and skipped with --resume.', + ) + resume: bool = clickdc.option( + '--resume', + is_flag=True, + help='In batch mode, resume after replaying statements in the --checkpoint file.', + ) + defaults_group_suffix: str | None = clickdc.option( + type=str, + help='Read MySQL config groups with the specified suffix.', + ) + defaults_file: str | None = clickdc.option( + type=click.Path(), + help='Only read MySQL options from the given file.', + ) + myclirc: str = clickdc.option( + type=click.Path(), + default='~/.myclirc', + help='Location of myclirc file.', + ) + auto_vertical_output: bool = clickdc.option( + is_flag=True, + help='Automatically switch to vertical output mode if the result is wider than the terminal width.', + ) + show_warnings: bool | None = clickdc.option( + '--show-warnings/--no-show-warnings', + is_flag=True, + default=None, + clickdc=None, + help='Automatically show warnings after executing a SQL statement.', + ) + table: bool = clickdc.option( + '-t', + is_flag=True, + help='Shorthand for --format=table.', + ) + csv: bool = clickdc.option( + is_flag=True, + help='Shorthand for --format=csv.', + ) + warn: bool | None = clickdc.option( + '--warn/--no-warn', + default=None, + clickdc=None, + help='Warn before running a destructive query.', + ) + local_infile: bool | None = clickdc.option( + type=bool, + is_flag=False, + default=None, + help='Enable/disable LOAD DATA LOCAL INFILE.', + ) + login_path: str | None = clickdc.option( + '-g', + type=str, + help='Read this path from the login file.', + ) + execute: str | None = clickdc.option( + '-e', + type=str, + help='Execute command and quit.', + ) + init_command: str | None = clickdc.option( + type=str, + help='SQL statement to execute after connecting.', + ) + unbuffered: bool | None = clickdc.option( + is_flag=True, + help='Instead of copying every row of data into a buffer, fetch rows as needed, to save memory.', + ) + character_set: str | None = clickdc.option( + '--charset', + '--character-set', + 'character_set', + type=str, + help='Character set for MySQL session.', + ) + batch: str | None = clickdc.option( + type=str, + help='SQL script to execute in batch mode.', + ) + noninteractive: bool = clickdc.option( + is_flag=True, + help="Don't prompt during batch input. Recommended.", + ) + format: str | None = clickdc.option( + type=click.Choice(['default', 'csv', 'tsv', 'table']), + help='Format for batch or --execute output.', + ) + throttle: float = clickdc.option( + type=float, + default=0.0, + help='Pause in seconds between queries in batch mode.', + ) + progress: bool = clickdc.option( + is_flag=True, + help='Show progress on the standard error with --batch.', + ) + use_keyring: str | None = clickdc.option( + type=click.Choice(['true', 'false', 'reset']), + default=None, + help='Store and retrieve passwords from the system keyring: true/false/reset.', + ) + keepalive_ticks: int | None = clickdc.option( + type=int, + help='Send regular keepalive pings to the connection, roughly every seconds.', + ) + checkup: bool = clickdc.option( + is_flag=True, + help='Run a checkup on your configuration.', + ) + + +def get_password_from_file(password_file: str | None) -> str | None: + if not password_file: + return None + try: + with open(password_file) as fp: + return fp.readline().removesuffix('\n') + except FileNotFoundError: + click.secho(f"Password file '{password_file}' not found", err=True, fg='red') + sys.exit(1) + except PermissionError: + click.secho(f"Permission denied reading password file '{password_file}'", err=True, fg='red') + sys.exit(1) + except IsADirectoryError: + click.secho(f"Path '{password_file}' is a directory, not a file", err=True, fg='red') + sys.exit(1) + except Exception as e: + click.secho(f"Error reading password file '{password_file}': {str(e)}", err=True, fg='red') + sys.exit(1) + + +def preprocess_cli_args( + cli_args: CliArgs, + is_valid_connection_scheme: Callable[[str], tuple[bool, str | None]], +) -> int: + if cli_args.database is None and isinstance(cli_args.password, str) and '://' in cli_args.password: + is_valid_scheme, scheme = is_valid_connection_scheme(cli_args.password) + if not is_valid_scheme: + click.secho(f'Error: Unknown connection scheme provided for DSN URI ({scheme}://)', err=True, fg='red') + sys.exit(1) + cli_args.database = cli_args.password + cli_args.password = EMPTY_PASSWORD_FLAG_SENTINEL + + if cli_args.password is None and cli_args.password_file: + password_from_file = get_password_from_file(cli_args.password_file) + if password_from_file is not None: + cli_args.password = password_from_file + + if cli_args.password is None and os.environ.get('MYSQL_PWD') is not None: + cli_args.password = os.environ.get('MYSQL_PWD') + + if cli_args.resume and not cli_args.checkpoint: + click.secho('Error: --resume requires a --checkpoint file.', err=True, fg='red') + sys.exit(1) + + if cli_args.resume and not cli_args.batch: + click.secho('Error: --resume requires a --batch file.', err=True, fg='red') + sys.exit(1) + + if cli_args.verbose and cli_args.quiet: + click.secho('Error: --verbose and --quiet are incompatible.', err=True, fg='red') + sys.exit(1) + elif cli_args.verbose: + return int(cli_args.verbose) + elif cli_args.quiet: + return -1 + return 0 + + @click.command() @clickdc.adddc('cli_args', CliArgs) @click.version_option(mycli_package.__version__, '--version', '-V', help="Output mycli's version.") diff --git a/mycli/main_modes/batch.py b/mycli/main_modes/batch.py index c820f02b..af14dd5f 100644 --- a/mycli/main_modes/batch.py +++ b/mycli/main_modes/batch.py @@ -17,8 +17,8 @@ from mycli.packages.sql_utils import is_destructive if TYPE_CHECKING: - from mycli.cli_args import CliArgs from mycli.client import MyCli + from mycli.main import CliArgs class CheckpointReplayError(Exception): diff --git a/mycli/main_modes/execute.py b/mycli/main_modes/execute.py index 641f8ad8..9de20df3 100644 --- a/mycli/main_modes/execute.py +++ b/mycli/main_modes/execute.py @@ -6,8 +6,8 @@ import click if TYPE_CHECKING: - from mycli.cli_args import CliArgs from mycli.client import MyCli + from mycli.main import CliArgs def main_execute_from_cli(mycli: 'MyCli', cli_args: 'CliArgs') -> int: diff --git a/mycli/main_modes/list_ssh_config.py b/mycli/main_modes/list_ssh_config.py index 14c2ff88..6927580b 100644 --- a/mycli/main_modes/list_ssh_config.py +++ b/mycli/main_modes/list_ssh_config.py @@ -7,8 +7,8 @@ from mycli.packages.ssh_utils import read_ssh_config if TYPE_CHECKING: - from mycli.cli_args import CliArgs from mycli.client import MyCli + from mycli.main import CliArgs def main_list_ssh_config(mycli: 'MyCli', cli_args: 'CliArgs') -> int: diff --git a/test/pytests/test_cli_args.py b/test/pytests/test_cli_args.py deleted file mode 100644 index f9171bdc..00000000 --- a/test/pytests/test_cli_args.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -import builtins -from pathlib import Path -from typing import Any - -import click -import pytest - -from mycli import cli_args as cli_args_module -from mycli.cli_args import ( - EMPTY_PASSWORD_FLAG_SENTINEL, - INT_OR_STRING_CLICK_TYPE, - CliArgs, - get_password_from_file, - preprocess_cli_args, -) - - -def valid_connection_scheme(value: str) -> tuple[bool, str | None]: - scheme, _, _ = value.partition('://') - return scheme == 'mysql', scheme or None - - -def test_int_or_string_click_type_accepts_int_string_and_none() -> None: - assert INT_OR_STRING_CLICK_TYPE.convert(7, None, None) == 7 - assert INT_OR_STRING_CLICK_TYPE.convert('secret', None, None) == 'secret' - assert INT_OR_STRING_CLICK_TYPE.convert(None, None, None) is None - - -def test_int_or_string_click_type_rejects_other_values() -> None: - with pytest.raises(click.BadParameter, match='Not a valid password string'): - INT_OR_STRING_CLICK_TYPE.convert(object(), None, None) - - -def test_get_password_from_file_reads_first_line_without_trailing_newline(tmp_path: Path) -> None: - password_file = tmp_path / 'password.txt' - password_file.write_text('secret\nignored\n', encoding='utf8') - - assert get_password_from_file(str(password_file)) == 'secret' - - -def test_get_password_from_file_returns_none_for_missing_path() -> None: - assert get_password_from_file(None) is None - assert get_password_from_file('') is None - - -@pytest.mark.parametrize( - ('exception', 'expected'), - [ - (FileNotFoundError(), "Password file 'secret.txt' not found"), - (PermissionError(), "Permission denied reading password file 'secret.txt'"), - (IsADirectoryError(), "Path 'secret.txt' is a directory, not a file"), - (RuntimeError('boom'), "Error reading password file 'secret.txt': boom"), - ], -) -def test_get_password_from_file_exits_with_error_for_read_failures( - monkeypatch: pytest.MonkeyPatch, - capsys: pytest.CaptureFixture[str], - exception: Exception, - expected: str, -) -> None: - def raise_error(*_args: Any, **_kwargs: Any) -> None: - raise exception - - monkeypatch.setattr(builtins, 'open', raise_error) - - with pytest.raises(SystemExit) as excinfo: - get_password_from_file('secret.txt') - - assert excinfo.value.code == 1 - assert expected in capsys.readouterr().err - - -def test_preprocess_cli_args_moves_dsn_from_password_to_database() -> None: - cli_args = CliArgs() - cli_args.password = 'mysql://user:pass@host/db' - - verbosity = preprocess_cli_args(cli_args, valid_connection_scheme) - - assert verbosity == 0 - assert cli_args.database == 'mysql://user:pass@host/db' - assert cli_args.password == EMPTY_PASSWORD_FLAG_SENTINEL # type: ignore[comparison-overlap] - - -def test_preprocess_cli_args_rejects_unknown_dsn_scheme(capsys: pytest.CaptureFixture[str]) -> None: - cli_args = CliArgs() - cli_args.password = 'postgres://user:pass@host/db' - - with pytest.raises(SystemExit) as excinfo: - preprocess_cli_args(cli_args, valid_connection_scheme) - - assert excinfo.value.code == 1 - assert 'Unknown connection scheme provided for DSN URI (postgres://)' in capsys.readouterr().err - - -def test_preprocess_cli_args_reads_password_file_when_password_missing( - monkeypatch: pytest.MonkeyPatch, -) -> None: - cli_args = CliArgs() - cli_args.password_file = 'secret.txt' - monkeypatch.setattr(cli_args_module, 'get_password_from_file', lambda password_file: f'from:{password_file}') - - assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 - assert cli_args.password == 'from:secret.txt' - - -def test_preprocess_cli_args_uses_mysql_pwd_when_password_and_file_missing(monkeypatch: pytest.MonkeyPatch) -> None: - cli_args = CliArgs() - monkeypatch.setenv('MYSQL_PWD', 'env-secret') - - assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 - assert cli_args.password == 'env-secret' - - -def test_preprocess_cli_args_prefers_existing_password_over_mysql_pwd(monkeypatch: pytest.MonkeyPatch) -> None: - cli_args = CliArgs() - cli_args.password = 'cli-secret' - monkeypatch.setenv('MYSQL_PWD', 'env-secret') - - assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 - assert cli_args.password == 'cli-secret' - - -@pytest.mark.parametrize( - ('checkpoint', 'batch', 'expected'), - [ - (None, 'batch.sql', 'Error: --resume requires a --checkpoint file.'), - (object(), None, 'Error: --resume requires a --batch file.'), - ], -) -def test_preprocess_cli_args_validates_resume_requirements( - capsys: pytest.CaptureFixture[str], - checkpoint: object | None, - batch: str | None, - expected: str, -) -> None: - cli_args = CliArgs() - cli_args.resume = True - cli_args.checkpoint = checkpoint # type: ignore[assignment] - cli_args.batch = batch - - with pytest.raises(SystemExit) as excinfo: - preprocess_cli_args(cli_args, valid_connection_scheme) - - assert excinfo.value.code == 1 - assert expected in capsys.readouterr().err - - -def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None: - cli_args = CliArgs() - cli_args.verbose = 1 - cli_args.quiet = True - - with pytest.raises(SystemExit) as excinfo: - preprocess_cli_args(cli_args, valid_connection_scheme) - - assert excinfo.value.code == 1 - assert 'Error: --verbose and --quiet are incompatible.' in capsys.readouterr().err - - -@pytest.mark.parametrize( - ('verbose', 'quiet', 'expected'), - [ - (2, False, 2), - (0, True, -1), - (0, False, 0), - ], -) -def test_preprocess_cli_args_returns_cli_verbosity(verbose: int, quiet: bool, expected: int) -> None: - cli_args = CliArgs() - cli_args.verbose = verbose - cli_args.quiet = quiet - - assert preprocess_cli_args(cli_args, valid_connection_scheme) == expected diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index 925bdd25..d65491a3 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -1,10 +1,13 @@ # type: ignore +from __future__ import annotations +import builtins from collections import namedtuple from contextlib import redirect_stderr, redirect_stdout import csv import io import os +from pathlib import Path import shutil import sys from tempfile import NamedTemporaryFile @@ -33,7 +36,15 @@ DEFAULT_USER, TEST_DATABASE, ) -from mycli.main import EMPTY_PASSWORD_FLAG_SENTINEL, MyCli, click_entrypoint +from mycli.main import ( + EMPTY_PASSWORD_FLAG_SENTINEL, + INT_OR_STRING_CLICK_TYPE, + CliArgs, + MyCli, + click_entrypoint, + get_password_from_file, + preprocess_cli_args, +) import mycli.main_modes.repl as repl_mode import mycli.output as output_module import mycli.packages.special @@ -2223,7 +2234,7 @@ def test_null_string_config(monkeypatch): ) myclirc.flush() args = CLI_ARGS_WITHOUT_DB + ['--myclirc', myclirc.name, '--format=table', '--execute', 'SELECT NULL'] - result = runner.invoke(mycli.main.click_entrypoint, args=args) + result = runner.invoke(main.click_entrypoint, args=args) assert '' in result.output assert '' not in result.output @@ -2480,3 +2491,161 @@ def __init__(self, **kwargs: Any) -> None: assert cli.sandbox_mode is True assert any('password has expired' in message for message in echo_calls) + + +def valid_connection_scheme(value: str) -> tuple[bool, str | None]: + scheme, _, _ = value.partition('://') + return scheme == 'mysql', scheme or None + + +def test_int_or_string_click_type_accepts_int_string_and_none() -> None: + assert INT_OR_STRING_CLICK_TYPE.convert(7, None, None) == 7 + assert INT_OR_STRING_CLICK_TYPE.convert('secret', None, None) == 'secret' + assert INT_OR_STRING_CLICK_TYPE.convert(None, None, None) is None + + +def test_int_or_string_click_type_rejects_other_values() -> None: + with pytest.raises(click.BadParameter, match='Not a valid password string'): + INT_OR_STRING_CLICK_TYPE.convert(object(), None, None) + + +def test_get_password_from_file_reads_first_line_without_trailing_newline(tmp_path: Path) -> None: + password_file = tmp_path / 'password.txt' + password_file.write_text('secret\nignored\n', encoding='utf8') + + assert get_password_from_file(str(password_file)) == 'secret' + + +def test_get_password_from_file_returns_none_for_missing_path() -> None: + assert get_password_from_file(None) is None + assert get_password_from_file('') is None + + +@pytest.mark.parametrize( + ('exception', 'expected'), + [ + (FileNotFoundError(), "Password file 'secret.txt' not found"), + (PermissionError(), "Permission denied reading password file 'secret.txt'"), + (IsADirectoryError(), "Path 'secret.txt' is a directory, not a file"), + (RuntimeError('boom'), "Error reading password file 'secret.txt': boom"), + ], +) +def test_get_password_from_file_exits_with_error_for_read_failures( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], + exception: Exception, + expected: str, +) -> None: + def raise_error(*_args: Any, **_kwargs: Any) -> None: + raise exception + + monkeypatch.setattr(builtins, 'open', raise_error) + + with pytest.raises(SystemExit) as excinfo: + get_password_from_file('secret.txt') + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_moves_dsn_from_password_to_database() -> None: + cli_args = CliArgs() + cli_args.password = 'mysql://user:pass@host/db' + + verbosity = preprocess_cli_args(cli_args, valid_connection_scheme) + + assert verbosity == 0 + assert cli_args.database == 'mysql://user:pass@host/db' + assert cli_args.password == EMPTY_PASSWORD_FLAG_SENTINEL # type: ignore[comparison-overlap] + + +def test_preprocess_cli_args_rejects_unknown_dsn_scheme(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.password = 'postgres://user:pass@host/db' + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Unknown connection scheme provided for DSN URI (postgres://)' in capsys.readouterr().err + + +def test_preprocess_cli_args_reads_password_file_when_password_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cli_args = CliArgs() + cli_args.password_file = 'secret.txt' + monkeypatch.setattr(main, 'get_password_from_file', lambda password_file: f'from:{password_file}') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'from:secret.txt' + + +def test_preprocess_cli_args_uses_mysql_pwd_when_password_and_file_missing(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'env-secret' + + +def test_preprocess_cli_args_prefers_existing_password_over_mysql_pwd(monkeypatch: pytest.MonkeyPatch) -> None: + cli_args = CliArgs() + cli_args.password = 'cli-secret' + monkeypatch.setenv('MYSQL_PWD', 'env-secret') + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == 0 + assert cli_args.password == 'cli-secret' + + +@pytest.mark.parametrize( + ('checkpoint', 'batch', 'expected'), + [ + (None, 'batch.sql', 'Error: --resume requires a --checkpoint file.'), + (object(), None, 'Error: --resume requires a --batch file.'), + ], +) +def test_preprocess_cli_args_validates_resume_requirements( + capsys: pytest.CaptureFixture[str], + checkpoint: object | None, + batch: str | None, + expected: str, +) -> None: + cli_args = CliArgs() + cli_args.resume = True + cli_args.checkpoint = checkpoint # type: ignore[assignment] + cli_args.batch = batch + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert expected in capsys.readouterr().err + + +def test_preprocess_cli_args_rejects_verbose_and_quiet(capsys: pytest.CaptureFixture[str]) -> None: + cli_args = CliArgs() + cli_args.verbose = 1 + cli_args.quiet = True + + with pytest.raises(SystemExit) as excinfo: + preprocess_cli_args(cli_args, valid_connection_scheme) + + assert excinfo.value.code == 1 + assert 'Error: --verbose and --quiet are incompatible.' in capsys.readouterr().err + + +@pytest.mark.parametrize( + ('verbose', 'quiet', 'expected'), + [ + (2, False, 2), + (0, True, -1), + (0, False, 0), + ], +) +def test_preprocess_cli_args_returns_cli_verbosity(verbose: int, quiet: bool, expected: int) -> None: + cli_args = CliArgs() + cli_args.verbose = verbose + cli_args.quiet = quiet + + assert preprocess_cli_args(cli_args, valid_connection_scheme) == expected diff --git a/test/pytests/test_main_regression.py b/test/pytests/test_main_regression.py index 42d3efdd..4d1168cd 100644 --- a/test/pytests/test_main_regression.py +++ b/test/pytests/test_main_regression.py @@ -38,9 +38,9 @@ import pytest from mycli import main -from mycli.cli_args import IntOrStringClickParamType import mycli.client_connection import mycli.key_bindings +from mycli.main import IntOrStringClickParamType import mycli.output as output_module from mycli.packages.sqlresult import SQLResult from test.utils import ( # type: ignore[attr-defined]