Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hailctl] Autocomplete for hailctl config {get,set,unset} #13224

Merged
merged 22 commits into from Aug 14, 2023
4 changes: 2 additions & 2 deletions hail/python/hail/backend/backend.py
Expand Up @@ -4,7 +4,7 @@
import pkg_resources
import zipfile

from hailtop.config.user_config import configuration_of
from hailtop.config.user_config import unchecked_configuration_of
from hailtop.fs.fs import FS

from ..builtin_references import BUILTIN_REFERENCE_RESOURCE_PATHS
Expand Down Expand Up @@ -197,7 +197,7 @@ def persist_expression(self, expr: Expression) -> Expression:

def _initialize_flags(self, initial_flags: Dict[str, str]) -> None:
self.set_flags(**{
k: configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar)
k: unchecked_configuration_of('query', k, None, default, deprecated_envvar=deprecated_envvar)
for k, (deprecated_envvar, default) in Backend._flags_env_vars_and_defaults.items()
if k not in initial_flags
}, **initial_flags)
Expand Down
20 changes: 10 additions & 10 deletions hail/python/hail/backend/service_backend.py
Expand Up @@ -19,7 +19,7 @@
from hail.ir.renderer import CSERenderer

from hailtop import yamlx
from hailtop.config import (configuration_of, get_remote_tmpdir)
from hailtop.config import (ConfigVariable, configuration_of, get_remote_tmpdir)
from hailtop.utils import async_to_blocking, secret_alnum_string, TransientError, Timings, am_i_interactive, retry_transient_errors
from hailtop.utils.rich_progress_bar import BatchProgressBar
from hailtop.batch_client import client as hb
Expand Down Expand Up @@ -206,7 +206,7 @@ async def create(*,
token: Optional[str] = None,
regions: Optional[List[str]] = None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None):
billing_project = configuration_of('batch', 'billing_project', billing_project, None)
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
if billing_project is None:
raise ValueError(
"No billing project. Call 'init_batch' with the billing "
Expand All @@ -224,17 +224,17 @@ async def create(*,
batch_attributes: Dict[str, str] = dict()
remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir)

jar_url = configuration_of('query', 'jar_url', jar_url, None)
jar_url = configuration_of(ConfigVariable.QUERY_JAR_URL, jar_url, None)
jar_spec = GitRevision(revision()) if jar_url is None else JarUrl(jar_url)

driver_cores = configuration_of('query', 'batch_driver_cores', driver_cores, None)
driver_memory = configuration_of('query', 'batch_driver_memory', driver_memory, None)
worker_cores = configuration_of('query', 'batch_worker_cores', worker_cores, None)
worker_memory = configuration_of('query', 'batch_worker_memory', worker_memory, None)
name_prefix = configuration_of('query', 'name_prefix', name_prefix, '')
driver_cores = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_CORES, driver_cores, None)
driver_memory = configuration_of(ConfigVariable.QUERY_BATCH_DRIVER_MEMORY, driver_memory, None)
worker_cores = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_CORES, worker_cores, None)
worker_memory = configuration_of(ConfigVariable.QUERY_BATCH_WORKER_MEMORY, worker_memory, None)
name_prefix = configuration_of(ConfigVariable.QUERY_NAME_PREFIX, name_prefix, '')

if regions is None:
regions_from_conf = configuration_of('batch', 'regions', regions, None)
regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, regions, None)
if regions_from_conf is not None:
assert isinstance(regions_from_conf, str)
regions = regions_from_conf.split(',')
Expand All @@ -245,7 +245,7 @@ async def create(*,
assert len(regions) > 0, regions

if disable_progress_bar is None:
disable_progress_bar_str = configuration_of('query', 'disable_progress_bar', None, None)
disable_progress_bar_str = configuration_of(ConfigVariable.QUERY_DISABLE_PROGRESS_BAR, None, None)
if disable_progress_bar_str is None:
disable_progress_bar = not am_i_interactive()
else:
Expand Down
11 changes: 11 additions & 0 deletions hail/python/hail/docs/install/macosx.rst
Expand Up @@ -14,3 +14,14 @@ Install Hail on Mac OS X
- Install Python 3.9 or later. We recommend `Miniconda <https://docs.conda.io/en/latest/miniconda.html#macosx-installers>`__.
- Open Terminal.app and execute ``pip install hail``. If this command fails with a message about "Rust", please try this instead: ``pip install hail --only-binary=:all:``.
- `Run your first Hail query! <try.rst>`__

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
hailctl Autocompletion (Optional)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

- Install autocompletion with ``hailctl --install-completion zsh``
- Ensure this line is in your zsh config file (~/.zshrc) and then reload your terminal.

.. code-block::

autoload -Uz compinit && compinit
4 changes: 2 additions & 2 deletions hail/python/hail/utils/java.py
Expand Up @@ -2,11 +2,11 @@
import sys
import re

from hailtop.config import configuration_of
from hailtop.config import ConfigVariable, configuration_of


def choose_backend(backend: Optional[str] = None) -> str:
return configuration_of('query', 'backend', backend, 'spark')
return configuration_of(ConfigVariable.QUERY_BACKEND, backend, 'spark')


class HailUserError(Exception):
Expand Down
6 changes: 3 additions & 3 deletions hail/python/hailtop/aiocloud/aiogoogle/user_config.py
Expand Up @@ -6,7 +6,7 @@
from dataclasses import dataclass


from hailtop.config.user_config import configuration_of
from hailtop.config.user_config import ConfigVariable, configuration_of


GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]]
Expand All @@ -19,8 +19,8 @@ def get_gcs_requester_pays_configuration(
if gcs_requester_pays_configuration:
return gcs_requester_pays_configuration

project = configuration_of('gcs_requester_pays', 'project', None, None)
buckets = configuration_of('gcs_requester_pays', 'buckets', None, None)
project = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None)
buckets = configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None)

spark_conf = get_spark_conf_gcs_requester_pays_configuration()

Expand Down
6 changes: 3 additions & 3 deletions hail/python/hailtop/batch/backend.py
Expand Up @@ -15,7 +15,7 @@
from rich.progress import track

from hailtop import pip_version
from hailtop.config import configuration_of, get_deploy_config, get_remote_tmpdir
from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir
from hailtop.utils.rich_progress_bar import SimpleRichProgressBar
from hailtop.utils import parse_docker_image_reference, async_to_blocking, bounded_gather, url_scheme
from hailtop.batch.hail_genetics_images import HAIL_GENETICS_IMAGES, hailgenetics_python_dill_image_for_current_python_version
Expand Down Expand Up @@ -474,7 +474,7 @@ def __init__(self,
warnings.warn('Use of deprecated positional argument \'bucket\' in ServiceBackend(). Specify \'bucket\' as a keyword argument instead.')
bucket = args[1]

billing_project = configuration_of('batch', 'billing_project', billing_project, None)
billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None)
if billing_project is None:
raise ValueError(
'the billing_project parameter of ServiceBackend must be set '
Expand All @@ -501,7 +501,7 @@ def __init__(self,
self.__fs: RouterAsyncFS = RouterAsyncFS(gcs_kwargs=gcs_kwargs)

if regions is None:
regions_from_conf = configuration_of('batch', 'regions', None, None)
regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, None, None)
if regions_from_conf is not None:
assert isinstance(regions_from_conf, str)
regions = regions_from_conf.split(',')
Expand Down
4 changes: 2 additions & 2 deletions hail/python/hailtop/batch/batch.py
Expand Up @@ -10,7 +10,7 @@
from hailtop.aiocloud.aioazure.fs import AzureAsyncFS
from hailtop.aiotools.router_fs import RouterAsyncFS
import hailtop.batch_client.client as _bc
from hailtop.config import configuration_of
from hailtop.config import ConfigVariable, configuration_of

from . import backend as _backend, job, resource as _resource # pylint: disable=cyclic-import
from .exceptions import BatchException
Expand Down Expand Up @@ -167,7 +167,7 @@ def __init__(self,
if backend:
self._backend = backend
else:
backend_config = configuration_of('batch', 'backend', None, 'local')
backend_config = configuration_of(ConfigVariable.BATCH_BACKEND, None, 'local')
if backend_config == 'service':
self._backend = _backend.ServiceBackend()
else:
Expand Down
5 changes: 4 additions & 1 deletion hail/python/hailtop/config/__init__.py
@@ -1,12 +1,15 @@
from .user_config import (get_user_config, get_user_config_path,
get_remote_tmpdir, configuration_of)
from .deploy_config import get_deploy_config, DeployConfig
from .variables import ConfigVariable, config_variables

__all__ = [
'get_deploy_config',
'get_user_config',
'get_user_config_path',
'get_remote_tmpdir',
'DeployConfig',
'configuration_of'
'ConfigVariable',
'configuration_of',
'config_variables',
]
30 changes: 20 additions & 10 deletions hail/python/hailtop/config/user_config.py
Expand Up @@ -5,6 +5,8 @@
import warnings
from pathlib import Path

from .variables import ConfigVariable, config_variables

user_config = None


Expand Down Expand Up @@ -36,15 +38,12 @@ def get_user_config() -> configparser.ConfigParser:
T = TypeVar('T')


def configuration_of(section: str,
option: str,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
assert VALID_SECTION_AND_OPTION_RE.fullmatch(section), (section, option)
assert VALID_SECTION_AND_OPTION_RE.fullmatch(option), (section, option)

def unchecked_configuration_of(section: str,
option: str,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
if explicit_argument is not None:
return explicit_argument

Expand All @@ -69,6 +68,17 @@ def configuration_of(section: str,
return fallback


def configuration_of(config_variable: ConfigVariable,
explicit_argument: Optional[T],
fallback: T,
*,
deprecated_envvar: Optional[str] = None) -> Union[str, T]:
config_variable_info = config_variables()[config_variable]
section = config_variable_info.section
option = config_variable_info.option
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of storing the section and option again in a map, you can do config_variable.value.split("/")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we went with your suggestion, we'd have to have something like this. Which do you prefer?

if '/' in config_variable.value:
    section, option = config_variable.value.split('/')
else:
    section = 'global'
    option = config_variable.value

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we went with your suggestion, we'd have to have something like this. Which do you prefer?

if '/' in config_variable.value:
    section, option = config_variable.value.split('/')
else:
    section = 'global'
    option = config_variable.value

I would prefer this. If I understand correctly it would allow us to delete the repeated section and option information in config_variables

return unchecked_configuration_of(section, option, explicit_argument, fallback, deprecated_envvar=deprecated_envvar)


def get_remote_tmpdir(caller_name: str,
*,
bucket: Optional[str] = None,
Expand All @@ -87,7 +97,7 @@ def get_remote_tmpdir(caller_name: str,
raise ValueError(f'Cannot specify both \'remote_tmpdir\' and \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.')

if bucket is None and remote_tmpdir is None:
remote_tmpdir = configuration_of('batch', 'remote_tmpdir', None, None)
remote_tmpdir = configuration_of(ConfigVariable.BATCH_REMOTE_TMPDIR, None, None)

if remote_tmpdir is None:
if bucket is None:
Expand Down
144 changes: 144 additions & 0 deletions hail/python/hailtop/config/variables.py
@@ -0,0 +1,144 @@
from collections import namedtuple
from enum import Enum
import re


_config_variables = None

ConfigVariableInfo = namedtuple('ConfigVariable', ['section', 'option', 'help_msg', 'validation'])


class ConfigVariable(str, Enum):
DOMAIN = 'domain'
GCS_REQUESTER_PAYS_PROJECT = 'gcs_requester_pays/project',
GCS_REQUESTER_PAYS_BUCKETS = 'gcs_requester_pays/buckets',
BATCH_BUCKET = 'batch/bucket'
BATCH_REMOTE_TMPDIR = 'batch/remote_tmpdir'
BATCH_REGIONS = 'batch/regions',
BATCH_BILLING_PROJECT = 'batch/billing_project'
BATCH_BACKEND = 'batch/backend'
QUERY_BACKEND = 'query/backend'
QUERY_JAR_URL = 'query/jar_url'
QUERY_BATCH_DRIVER_CORES = 'query/batch_driver_cores'
QUERY_BATCH_WORKER_CORES = 'query/batch_worker_cores'
QUERY_BATCH_DRIVER_MEMORY = 'query/batch_driver_memory'
QUERY_BATCH_WORKER_MEMORY = 'query/batch_worker_memory'
QUERY_NAME_PREFIX = 'query/name_prefix'
QUERY_DISABLE_PROGRESS_BAR = 'query/disable_progress_bar'


def config_variables():
from hailtop.batch_client.parse import CPU_REGEXPAT, MEMORY_REGEXPAT # pylint: disable=import-outside-toplevel
from hailtop.fs.router_fs import RouterAsyncFS # pylint: disable=import-outside-toplevel

global _config_variables

if _config_variables is None:
_config_variables = {
ConfigVariable.DOMAIN: ConfigVariableInfo(
section='global',
option='domain',
help_msg='Domain of the Batch service',
validation=(lambda x: re.fullmatch(r'.+\..+', x) is not None, 'should be valid domain'),
),
ConfigVariable.GCS_REQUESTER_PAYS_PROJECT: ConfigVariableInfo(
section='gcs_requester_pays',
option='project',
help_msg='Project when using requester pays buckets in GCS',
validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, 'should be valid GCS project name'),
),
ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS: ConfigVariableInfo(
section='gcs_requester_pays',
option='buckets',
help_msg='Allowed buckets when using requester pays in GCS',
validation=(
lambda x: re.fullmatch(r'[^:/\s]+(,[^:/\s]+)*', x) is not None,
'should be comma separated list of bucket names'),
),
ConfigVariable.BATCH_BUCKET: ConfigVariableInfo(
section='batch',
option='bucket',
help_msg='Deprecated - Name of GCS bucket to use as a temporary scratch directory',
validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None,
'should be valid Google Bucket identifier, with no gs:// prefix'),
),
ConfigVariable.BATCH_REMOTE_TMPDIR: ConfigVariableInfo(
section='batch',
option='remote_tmpdir',
help_msg='Cloud storage URI to use as a temporary scratch directory',
validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/'),
),
ConfigVariable.BATCH_REGIONS: ConfigVariableInfo(
section='batch',
option='regions',
help_msg='Comma-separated list of regions to run jobs in',
validation=(
lambda x: re.fullmatch(r'[^\s]+(,[^\s]+)*', x) is not None, 'should be comma separated list of regions'),
),
ConfigVariable.BATCH_BILLING_PROJECT: ConfigVariableInfo(
section='batch',
option='billing_project',
help_msg='Batch billing project',
validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, 'should be valid Batch billing project name'),
),
ConfigVariable.BATCH_BACKEND: ConfigVariableInfo(
section='batch',
option='backend',
help_msg='Backend to use. One of local or service.',
validation=(lambda x: x in ('local', 'service'), 'should be one of "local" or "service"'),
),
ConfigVariable.QUERY_BACKEND: ConfigVariableInfo(
section='query',
option='backend',
help_msg='Backend to use for Hail Query. One of spark, local, batch.',
validation=(lambda x: x in ('local', 'spark', 'batch'), 'should be one of "local", "spark", or "batch"'),
),
ConfigVariable.QUERY_JAR_URL: ConfigVariableInfo(
section='query',
option='jar_url',
help_msg='Cloud storage URI to a Query JAR',
validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/jars/sha.jar')
),
ConfigVariable.QUERY_BATCH_DRIVER_CORES: ConfigVariableInfo(
section='query',
option='batch_driver_cores',
help_msg='Cores specification for the query driver',
validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None,
'should be an integer which is a power of two from 1 to 16 inclusive'),
),
ConfigVariable.QUERY_BATCH_WORKER_CORES: ConfigVariableInfo(
section='query',
option='batch_worker_cores',
help_msg='Cores specification for the query worker',
validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None,
'should be an integer which is a power of two from 1 to 16 inclusive'),
),
ConfigVariable.QUERY_BATCH_DRIVER_MEMORY: ConfigVariableInfo(
section='query',
option='batch_driver_memory',
help_msg='Memory specification for the query driver',
validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'),
'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem'),
),
ConfigVariable.QUERY_BATCH_WORKER_MEMORY: ConfigVariableInfo(
section='query',
option='batch_worker_memory',
help_msg='Memory specification for the query worker',
validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'),
'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem'),
),
ConfigVariable.QUERY_NAME_PREFIX: ConfigVariableInfo(
section='query',
option='name_prefix',
help_msg='Name used when displaying query progress in a progress bar',
validation=(lambda x: re.fullmatch(r'[^\s]+', x) is not None, 'should be single word without spaces'),
),
ConfigVariable.QUERY_DISABLE_PROGRESS_BAR: ConfigVariableInfo(
section='query',
option='disable_progress_bar',
help_msg='Disable the progress bar with a value of 1. Enable the progress bar with a value of 0',
validation=(lambda x: x in ('0', '1'), 'should be a value of 0 or 1'),
),
}

return _config_variables