From 957bd4804174dfbfb99f6f06e51495333cd8bf84 Mon Sep 17 00:00:00 2001 From: favilo Date: Thu, 20 Jun 2024 18:12:57 -0700 Subject: [PATCH 01/26] Modify rally to allow multiple cars with complex configuration --- esrally/mechanic/team.py | 44 ++++++++++++------ esrally/utils/io.py | 7 +-- esrally/utils/modules.py | 98 ++++++++++++++++++++++------------------ esrally/utils/process.py | 4 +- 4 files changed, 92 insertions(+), 61 deletions(-) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 6ea7cabc1..aab92ccb4 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -19,6 +19,8 @@ import logging import os from enum import Enum +from typing import Any, Collection, Mapping, Union +from types import ModuleType import tabulate @@ -53,7 +55,7 @@ def __init__(self, root_path, entry_point): self.root_path = root_path self.entry_point = entry_point - root_path = None + root_paths = [] # preserve order as we append to existing config files later during provisioning. all_config_paths = [] all_config_base_vars = {} @@ -67,11 +69,11 @@ def __init__(self, root_path, entry_point): for p in descriptor.root_paths: # probe whether we have a root path if BootstrapHookHandler(Component(root_path=p, entry_point=Car.entry_point)).can_load(): - if not root_path: - root_path = p + if p not in root_paths: + root_paths.append(p) # multiple cars are based on the same hook - elif root_path != p: - raise exceptions.SystemSetupError(f"Invalid car: {name}. Multiple bootstrap hooks are forbidden.") + # elif root_paths != p: + # raise exceptions.SystemSetupError(f"Invalid car: {name}. 
Multiple bootstrap hooks are forbidden.") all_config_base_vars.update(descriptor.config_base_variables) all_car_vars.update(descriptor.variables) @@ -82,7 +84,7 @@ def __init__(self, root_path, entry_point): variables.update(all_config_base_vars) variables.update(all_car_vars) - return Car(name, root_path, all_config_paths, variables) + return Car(name, root_paths, all_config_paths, variables) def list_plugins(cfg: types.Config): @@ -235,12 +237,18 @@ class Car: # name of the initial Python file to load for cars. entry_point = "config" - def __init__(self, names, root_path, config_paths, variables=None): + def __init__( + self, + names: Collection[str], + root_path: Union[str, Collection[str]], + config_paths: Collection[str], + variables: Mapping[str, Any] = None, + ): """ Creates new settings for a benchmark candidate. :param names: Descriptive name(s) for this car. - :param root_path: The root path from which bootstrap hooks should be loaded if any. May be ``None``. + :param root_path: The root path(s) from which bootstrap hooks should be loaded if any. May be ``[]``. :param config_paths: A non-empty list of paths where the raw config can be found. :param variables: A dict containing variable definitions that need to be replaced. """ @@ -250,7 +258,11 @@ def __init__(self, names, root_path, config_paths, variables=None): self.names = [names] else: self.names = names - self.root_path = root_path + + if isinstance(root_path, str): + self.root_path = [root_path] + else: + self.root_path = root_path self.config_paths = config_paths self.variables = variables @@ -481,7 +493,11 @@ def __init__(self, component, loader_class=modules.ComponentLoader): self.component = component # Don't allow the loader to recurse. The subdirectories may contain Elasticsearch specific files which we do not want to add to # Rally's Python load path. We may need to define a more advanced strategy in the future. 
- self.loader = loader_class(root_path=self.component.root_path, component_entry_point=self.component.entry_point, recurse=False) + if isinstance(self.component.root_path, list): + root_path = self.component.root_path + else: + root_path = [self.component.root_path] + self.loader = loader_class(root_path=root_path, component_entry_point=self.component.entry_point, recurse=False) self.hooks = {} self.logger = logging.getLogger(__name__) @@ -489,17 +505,19 @@ def can_load(self): return self.loader.can_load() def load(self): - root_module = self.loader.load() + root_modules: Collection[ModuleType] = self.loader.load() try: # every module needs to have a register() method - root_module.register(self) + for module in root_modules: + module.register(self) except exceptions.RallyError: # just pass our own exceptions transparently. raise except BaseException: msg = f"Could not load bootstrap hooks in [{self.loader.root_path}]" self.logger.exception(msg) - raise exceptions.SystemSetupError(msg) + raise + # raise exceptions.SystemSetupError(msg) def register(self, phase, hook): self.logger.info("Registering bootstrap hook [%s] for phase [%s] in component [%s]", hook.__name__, phase, self.component.name) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 7452e2b71..6e4ff4f74 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -26,6 +26,7 @@ import subprocess import tarfile import zipfile +from typing import AnyStr import zstandard @@ -379,15 +380,15 @@ def _do_decompress(target_directory, compressed_file): # just in a dedicated method to ease mocking -def dirname(path): +def dirname(path: AnyStr): return os.path.dirname(path) -def basename(path): +def basename(path: AnyStr): return os.path.basename(path) -def exists(path): +def exists(path: AnyStr): return os.path.exists(path) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 728b55d3a..9b5dd2aa3 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -19,6 +19,8 @@ 
import logging import os import sys +from typing import Collection, Union +from types import ModuleType from esrally import exceptions from esrally.utils import io @@ -34,36 +36,40 @@ class ComponentLoader: """ - def __init__(self, root_path, component_entry_point, recurse=True): + def __init__(self, root_path: Union[str, Collection[str]], component_entry_point: str, recurse: bool = True): """ Creates a new component loader. - :param root_path: An absolute path to a directory which contains the component entry point. + :param root_path: An absolute path or list of paths to a directory which contains the component entry point. :param component_entry_point: The name of the component entry point. A corresponding file with the extension ".py" must exist in the ``root_path``. :param recurse: Search recursively for modules but ignore modules starting with "_" (Default: ``True``). """ - self.root_path = root_path + self.root_path: Collection[str] = root_path if isinstance(root_path, list) else [root_path] self.component_entry_point = component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) - def _modules(self, module_paths, component_name): + def _modules(self, module_paths: Collection[str], component_name: str, root_path: str): for path in module_paths: for filename in os.listdir(path): name, ext = os.path.splitext(filename) if ext.endswith(".py"): - root_relative_path = os.path.join(path, name)[len(self.root_path) + len(os.path.sep) :] + file_absolute_path = os.path.join(path, filename) + root_absolute_path = os.path.join(path, name) + root_relative_path = root_absolute_path[len(root_path) + len(os.path.sep) :] module_name = "%s.%s" % (component_name, root_relative_path.replace(os.path.sep, ".")) - yield module_name + yield module_name, file_absolute_path - def _load_component(self, component_name, module_dirs): + def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str): # precondition: A module with this 
name has to exist provided that the caller has called #can_load() before. root_module_name = "%s.%s" % (component_name, self.component_entry_point) - for p in self._modules(module_dirs, component_name): - self.logger.debug("Loading module [%s]", p) - m = importlib.import_module(p) - if p == root_module_name: + for name, p in self._modules(module_dirs, component_name, root_path): + self.logger.debug(f"Loading module [{name}]: {p}") + spec = importlib.util.spec_from_file_location(name, p) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + if name == root_module_name: root_module = m return root_module @@ -71,41 +77,47 @@ def can_load(self): """ :return: True iff the component entry point could be found. """ - return self.root_path and os.path.exists(os.path.join(self.root_path, "%s.py" % self.component_entry_point)) + return self.root_path and all( + os.path.exists(os.path.join(root_path, "%s.py" % self.component_entry_point)) for root_path in self.root_path + ) - def load(self): + def load(self) -> Collection[ModuleType]: """ - Loads a component with the given component entry point. + Loads components with the given component entry point. Precondition: ``ComponentLoader#can_load() == True``. - :return: The root module. + :return: The root modules. 
""" - component_name = io.basename(self.root_path) - self.logger.info("Loading component [%s] from [%s]", component_name, self.root_path) - module_dirs = [] - # search all paths within this directory for modules but exclude all directories starting with "_" - if self.recurse: - for dirpath, dirs, _ in os.walk(self.root_path): - module_dirs.append(dirpath) - ignore = [] - for d in dirs: - if d.startswith("_"): - self.logger.debug("Removing [%s] from load path.", d) - ignore.append(d) - for d in ignore: - dirs.remove(d) - else: - module_dirs.append(self.root_path) - # load path is only the root of the package hierarchy - component_root_path = os.path.abspath(os.path.join(self.root_path, os.pardir)) - self.logger.debug("Adding [%s] to Python load path.", component_root_path) - # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules - sys.path.insert(0, component_root_path) - try: - root_module = self._load_component(component_name, module_dirs) - return root_module - except BaseException: - msg = f"Could not load component [{component_name}]" - self.logger.exception(msg) - raise exceptions.SystemSetupError(msg) + root_modules = [] + for root_path in self.root_path: + component_name = io.basename(root_path) + self.logger.info("Loading component [%s] from [%s]", component_name, root_path) + module_dirs = [] + # search all paths within this directory for modules but exclude all directories starting with "_" + if self.recurse: + for dirpath, dirs, _ in os.walk(root_path): + module_dirs.append(dirpath) + ignore = [] + for d in dirs: + if d.startswith("_"): + self.logger.debug("Removing [%s] from load path.", d) + ignore.append(d) + for d in ignore: + dirs.remove(d) + else: + module_dirs.append(root_path) + # load path is only the root of the package hierarchy + component_root_path = os.path.abspath(os.path.join(root_path, os.pardir)) + self.logger.debug("Adding [%s] to Python load path.", component_root_path) 
+ # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules + sys.path.insert(0, component_root_path) + try: + root_module = self._load_component(component_name, module_dirs, root_path) + root_modules.append(root_module) + except BaseException: + msg = f"Could not load component [{component_name}]" + self.logger.exception(msg) + raise + # raise exceptions.SystemSetupError(msg) + return root_modules diff --git a/esrally/utils/process.py b/esrally/utils/process.py index 6a283a723..8a1cd1e48 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -20,7 +20,7 @@ import shlex import subprocess import time -from typing import Callable, Dict, List +from typing import IO, Callable, Dict, List, Optional, Union import psutil @@ -74,7 +74,7 @@ def run_subprocess_with_logging( command_line: str, header: str = None, level: LogLevel = logging.INFO, - stdin: FileId = None, + stdin: Optional[Union[FileId, IO[bytes]]] = None, env: Dict[str, str] = None, detach: bool = False, ) -> int: From ab3bd182c51a452b547f9b1641fb5fea23717a01 Mon Sep 17 00:00:00 2001 From: favilo Date: Thu, 20 Jun 2024 18:46:10 -0700 Subject: [PATCH 02/26] Fixing tests --- esrally/mechanic/team.py | 2 +- esrally/utils/modules.py | 2 +- tests/mechanic/provisioner_test.py | 2 +- tests/mechanic/team_test.py | 28 +++++++++++++--------------- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index aab92ccb4..f0f2b5d23 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -49,7 +49,7 @@ def list_cars(cfg: types.Config): console.println(tabulate.tabulate([[c.name, c.type, c.description] for c in cars], headers=["Name", "Type", "Description"])) -def load_car(repo, name, car_params=None): +def load_car(repo: str, name: Collection[str], car_params: Mapping=None) -> "Car": class Component: def __init__(self, root_path, entry_point): self.root_path = root_path 
diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 9b5dd2aa3..555663954 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -77,7 +77,7 @@ def can_load(self): """ :return: True iff the component entry point could be found. """ - return self.root_path and all( + return all(self.root_path) and all( os.path.exists(os.path.join(root_path, "%s.py" % self.component_entry_point)) for root_path in self.root_path ) diff --git a/tests/mechanic/provisioner_test.py b/tests/mechanic/provisioner_test.py index fe6aeae4b..d4a8ac982 100644 --- a/tests/mechanic/provisioner_test.py +++ b/tests/mechanic/provisioner_test.py @@ -42,7 +42,7 @@ def null_apply_config(source_root_path, target_root_path, config_vars): installer = provisioner.ElasticsearchInstaller( car=team.Car( names="unit-test-car", - root_path=None, + root_path=None, # type: ignore config_paths=[HOME_DIR + "/.rally/benchmarks/teams/default/my-car"], variables={ "heap": "4g", diff --git a/tests/mechanic/team_test.py b/tests/mechanic/team_test.py index 32dcbedcf..6fb0fce04 100644 --- a/tests/mechanic/team_test.py +++ b/tests/mechanic/team_test.py @@ -48,17 +48,17 @@ def test_load_known_car(self): car = team.load_car(self.team_dir, ["default"], car_params={"data_paths": ["/mnt/disk0", "/mnt/disk1"]}) assert car.name == "default" assert car.config_paths == [os.path.join(current_dir, "data", "cars", "v1", "vanilla", "templates")] - assert car.root_path is None + assert car.root_path == [] assert car.variables == {"heap_size": "1g", "clean_command": "./gradlew clean", "data_paths": ["/mnt/disk0", "/mnt/disk1"]} - assert car.root_path is None + assert car.root_path == [] def test_load_car_with_mixin_single_config_base(self): car = team.load_car(self.team_dir, ["32gheap", "ea"]) assert car.name == "32gheap+ea" assert car.config_paths == [os.path.join(current_dir, "data", "cars", "v1", "vanilla", "templates")] - assert car.root_path is None + assert car.root_path == [] assert car.variables 
== {"heap_size": "32g", "clean_command": "./gradlew clean", "assertions": "true"} - assert car.root_path is None + assert car.root_path == [] def test_load_car_with_mixin_multiple_config_bases(self): car = team.load_car(self.team_dir, ["32gheap", "ea", "verbose"]) @@ -67,7 +67,7 @@ def test_load_car_with_mixin_multiple_config_bases(self): os.path.join(current_dir, "data", "cars", "v1", "vanilla", "templates"), os.path.join(current_dir, "data", "cars", "v1", "verbose_logging", "templates"), ] - assert car.root_path is None + assert car.root_path == [] assert car.variables == {"heap_size": "32g", "clean_command": "./gradlew clean", "verbose_logging": "true", "assertions": "true"} def test_load_car_with_install_hook(self): @@ -77,7 +77,7 @@ def test_load_car_with_install_hook(self): os.path.join(current_dir, "data", "cars", "v1", "vanilla", "templates"), os.path.join(current_dir, "data", "cars", "v1", "with_hook", "templates"), ] - assert car.root_path == os.path.join(current_dir, "data", "cars", "v1", "with_hook") + assert car.root_path == [os.path.join(current_dir, "data", "cars", "v1", "with_hook")] assert car.variables == {"heap_size": "1g", "clean_command": "./gradlew clean", "data_paths": ["/mnt/disk0", "/mnt/disk1"]} def test_load_car_with_multiple_bases_referring_same_install_hook(self): @@ -88,7 +88,7 @@ def test_load_car_with_multiple_bases_referring_same_install_hook(self): os.path.join(current_dir, "data", "cars", "v1", "with_hook", "templates"), os.path.join(current_dir, "data", "cars", "v1", "verbose_logging", "templates"), ] - assert car.root_path == os.path.join(current_dir, "data", "cars", "v1", "with_hook") + assert car.root_path == [os.path.join(current_dir, "data", "cars", "v1", "with_hook")] assert car.variables == {"heap_size": "16g", "clean_command": "./gradlew clean", "verbose_logging": "true"} def test_raises_error_on_unknown_car(self): @@ -112,12 +112,10 @@ def test_raises_error_on_missing_config_base(self): ): team.load_car(self.team_dir, 
["missing_cfg_base"]) - def test_raises_error_if_more_than_one_different_install_hook(self): - with pytest.raises( - exceptions.SystemSetupError, - match=r"Invalid car: \['multi_hook'\]. Multiple bootstrap hooks are forbidden.", - ): - team.load_car(self.team_dir, ["multi_hook"]) + def test_doesnt_raise_error_if_more_than_one_different_install_hook(self): + car = team.load_car(self.team_dir, ["multi_hook"]) + assert isinstance(car.root_path, list) + assert len(car.root_path) == 2 class TestPluginLoader: @@ -229,7 +227,7 @@ def test_loads_module(self): hook = self.UnitTestHook() handler = team.BootstrapHookHandler(plugin, loader_class=self.UnitTestComponentLoader) - handler.loader.registration_function = hook + handler.loader.registration_function = [hook] handler.load() handler.invoke("post_install", variables={"increment": 4}) @@ -242,7 +240,7 @@ def test_cannot_register_for_unknown_phase(self): hook = self.UnitTestHook(phase="this_is_an_unknown_install_phase") handler = team.BootstrapHookHandler(plugin, loader_class=self.UnitTestComponentLoader) - handler.loader.registration_function = hook + handler.loader.registration_function = [hook] with pytest.raises(exceptions.SystemSetupError) as exc: handler.load() assert exc.value.args[0] == "Unknown bootstrap phase [this_is_an_unknown_install_phase]. Valid phases are: ['post_install']." 
From 8dd21536aa42e91f766aa55a7dfcc137f0a4838a Mon Sep 17 00:00:00 2001 From: favilo Date: Thu, 20 Jun 2024 18:56:00 -0700 Subject: [PATCH 03/26] Fixing lints --- esrally/mechanic/team.py | 10 ++++++---- esrally/utils/modules.py | 7 +++---- esrally/utils/process.py | 1 - tests/mechanic/provisioner_test.py | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index f0f2b5d23..e12af2400 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -19,8 +19,8 @@ import logging import os from enum import Enum -from typing import Any, Collection, Mapping, Union from types import ModuleType +from typing import Any, Collection, Mapping, Union import tabulate @@ -49,7 +49,7 @@ def list_cars(cfg: types.Config): console.println(tabulate.tabulate([[c.name, c.type, c.description] for c in cars], headers=["Name", "Type", "Description"])) -def load_car(repo: str, name: Collection[str], car_params: Mapping=None) -> "Car": +def load_car(repo: str, name: Collection[str], car_params: Mapping = None) -> "Car": class Component: def __init__(self, root_path, entry_point): self.root_path = root_path @@ -240,7 +240,7 @@ class Car: def __init__( self, names: Collection[str], - root_path: Union[str, Collection[str]], + root_path: Union[None, str, Collection[str]], config_paths: Collection[str], variables: Mapping[str, Any] = None, ): @@ -259,7 +259,9 @@ def __init__( else: self.names = names - if isinstance(root_path, str): + if root_path is None: + self.root_path = [] + elif isinstance(root_path, str): self.root_path = [root_path] else: self.root_path = root_path diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 555663954..b0396becb 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -19,8 +19,8 @@ import logging import os import sys -from typing import Collection, Union from types import ModuleType +from typing import Collection, Union from esrally import exceptions 
from esrally.utils import io @@ -65,7 +65,7 @@ def _load_component(self, component_name: str, module_dirs: Collection[str], roo # precondition: A module with this name has to exist provided that the caller has called #can_load() before. root_module_name = "%s.%s" % (component_name, self.component_entry_point) for name, p in self._modules(module_dirs, component_name, root_path): - self.logger.debug(f"Loading module [{name}]: {p}") + self.logger.debug("Loading module [%s]: %s", name, p) spec = importlib.util.spec_from_file_location(name, p) m = importlib.util.module_from_spec(spec) spec.loader.exec_module(m) @@ -118,6 +118,5 @@ def load(self) -> Collection[ModuleType]: except BaseException: msg = f"Could not load component [{component_name}]" self.logger.exception(msg) - raise - # raise exceptions.SystemSetupError(msg) + raise exceptions.SystemSetupError(msg) return root_modules diff --git a/esrally/utils/process.py b/esrally/utils/process.py index 8a1cd1e48..c26c4c0f2 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -142,7 +142,6 @@ def run_subprocess_with_logging_and_output( if header is not None: logger.info(header) - # pylint: disable=subprocess-popen-preexec-fn completed = subprocess.run( command_line_args, stdout=subprocess.PIPE, diff --git a/tests/mechanic/provisioner_test.py b/tests/mechanic/provisioner_test.py index d4a8ac982..fe6aeae4b 100644 --- a/tests/mechanic/provisioner_test.py +++ b/tests/mechanic/provisioner_test.py @@ -42,7 +42,7 @@ def null_apply_config(source_root_path, target_root_path, config_vars): installer = provisioner.ElasticsearchInstaller( car=team.Car( names="unit-test-car", - root_path=None, # type: ignore + root_path=None, config_paths=[HOME_DIR + "/.rally/benchmarks/teams/default/my-car"], variables={ "heap": "4g", From 6b9791b3a2f32a2ab56910ee7825c1572ae810e2 Mon Sep 17 00:00:00 2001 From: favilo Date: Fri, 21 Jun 2024 10:16:06 -0700 Subject: [PATCH 04/26] Fix rally-tracks-compat nox test --- 
esrally/track/loader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/esrally/track/loader.py b/esrally/track/loader.py index 56a299d13..6c6f947e1 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -1226,10 +1226,11 @@ def load(self): # get dependent libraries installed in a prior step. ensure dir exists to make sure loading works correctly. os.makedirs(paths.libs(), exist_ok=True) sys.path.insert(0, paths.libs()) - root_module = self.loader.load() + root_modules = self.loader.load() try: # every module needs to have a register() method - root_module.register(self) + for module in root_modules: + module.register(self) except BaseException: msg = "Could not register track plugin at [%s]" % self.loader.root_path logging.getLogger(__name__).exception(msg) From 5d65ef57cbc47920b4155d247a15e204c61c200e Mon Sep 17 00:00:00 2001 From: favilo Date: Fri, 28 Jun 2024 11:16:17 -0700 Subject: [PATCH 05/26] Fix issue pointed out by @AI-IshanBhatt Also revert naked `raise` back to `SystemSetupError`. I'd only done that for debugging purposes --- esrally/mechanic/team.py | 6 +----- esrally/utils/modules.py | 1 - 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index e12af2400..5309c5870 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -71,9 +71,6 @@ def __init__(self, root_path, entry_point): if BootstrapHookHandler(Component(root_path=p, entry_point=Car.entry_point)).can_load(): if p not in root_paths: root_paths.append(p) - # multiple cars are based on the same hook - # elif root_paths != p: - # raise exceptions.SystemSetupError(f"Invalid car: {name}. 
Multiple bootstrap hooks are forbidden.") all_config_base_vars.update(descriptor.config_base_variables) all_car_vars.update(descriptor.variables) @@ -518,8 +515,7 @@ def load(self): except BaseException: msg = f"Could not load bootstrap hooks in [{self.loader.root_path}]" self.logger.exception(msg) - raise - # raise exceptions.SystemSetupError(msg) + raise exceptions.SystemSetupError(msg) def register(self, phase, hook): self.logger.info("Registering bootstrap hook [%s] for phase [%s] in component [%s]", hook.__name__, phase, self.component.name) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index b0396becb..a6bf537ee 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -111,7 +111,6 @@ def load(self) -> Collection[ModuleType]: component_root_path = os.path.abspath(os.path.join(root_path, os.pardir)) self.logger.debug("Adding [%s] to Python load path.", component_root_path) # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules - sys.path.insert(0, component_root_path) try: root_module = self._load_component(component_name, module_dirs, root_path) root_modules.append(root_module) From 72a6009df22676a372c59b39bc984b4614f1d017 Mon Sep 17 00:00:00 2001 From: favilo Date: Fri, 28 Jun 2024 11:25:48 -0700 Subject: [PATCH 06/26] Fix mypy `type-arg` error I introduced --- esrally/utils/modules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index a6bf537ee..0fa2b45a8 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -67,6 +67,8 @@ def _load_component(self, component_name: str, module_dirs: Collection[str], roo for name, p in self._modules(module_dirs, component_name, root_path): self.logger.debug("Loading module [%s]: %s", name, p) spec = importlib.util.spec_from_file_location(name, p) + if spec is None: + raise exceptions.SystemSetupError(f"Could not load module [{name}]") m = 
importlib.util.module_from_spec(spec) spec.loader.exec_module(m) if name == root_module_name: From f61259e306e88b28ebb72e4d868f26b0f0a18397 Mon Sep 17 00:00:00 2001 From: favilo Date: Fri, 28 Jun 2024 12:07:36 -0700 Subject: [PATCH 07/26] Add sys.path.insert back It is the fastest way to get this working without writing my own Importer --- esrally/utils/modules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 0fa2b45a8..8a228d1bc 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -113,6 +113,7 @@ def load(self) -> Collection[ModuleType]: component_root_path = os.path.abspath(os.path.join(root_path, os.pardir)) self.logger.debug("Adding [%s] to Python load path.", component_root_path) # needs to be at the beginning of the system path, otherwise import machinery tries to load application-internal modules + sys.path.insert(0, component_root_path) try: root_module = self._load_component(component_name, module_dirs, root_path) root_modules.append(root_module) From a2f80732139642c2d3e6890328b452cc1c0a3b94 Mon Sep 17 00:00:00 2001 From: Grzegorz Banasiak Date: Mon, 1 Jul 2024 10:31:32 +0200 Subject: [PATCH 08/26] Introduce mypy overrides --- .pre-commit-config.yaml | 6 +++ esrally/client/asynchronous.py | 4 +- esrally/mechanic/team.py | 16 +++--- esrally/utils/modules.py | 6 +-- pyproject.toml | 53 ++++++++++++++++++-- tests/client/factory_test.py | 16 ++++-- tests/driver/driver_test.py | 40 +++++++++++++-- tests/driver/runner_test.py | 86 ++++++++++++++++++++++++++------- tests/mechanic/launcher_test.py | 2 +- tests/metrics_test.py | 15 ++++-- tests/telemetry_test.py | 53 ++++++++++++++++++-- tests/utils/collections_test.py | 2 +- tests/utils/net_test.py | 7 ++- tests/utils/versions_test.py | 2 +- 14 files changed, 251 insertions(+), 57 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 989f09e70..43cbdfeee 100644 --- a/.pre-commit-config.yaml +++ 
b/.pre-commit-config.yaml @@ -22,7 +22,13 @@ repos: rev: v1.6.1 hooks: - id: mypy + additional_dependencies: [ + "elasticsearch[async]==8.6.1", + "elastic-transport==8.4.1", + "types-tabulate", + ] args: [ + "--ignore-missing-imports", "--config", "pyproject.toml" ] diff --git a/esrally/client/asynchronous.py b/esrally/client/asynchronous.py index 2d869390d..9c32efc44 100644 --- a/esrally/client/asynchronous.py +++ b/esrally/client/asynchronous.py @@ -77,7 +77,7 @@ async def send(self, conn: "Connection") -> "ClientResponse": self.response = self.response_class( self.method, self.original_url, - writer=self._writer, + writer=self._writer, # type: ignore[arg-type] # TODO remove this ignore when introducing type hints continue100=self._continue, timer=self._timer, request_info=self.request_info, @@ -223,7 +223,7 @@ def __init__(self, config): self._loop = None self.client_id = None self.trace_configs = None - self.enable_cleanup_closed = None + self.enable_cleanup_closed = False self._static_responses = None self._request_class = aiohttp.ClientRequest self._response_class = aiohttp.ClientResponse diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 5309c5870..6023e4698 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -20,7 +20,7 @@ import os from enum import Enum from types import ModuleType -from typing import Any, Collection, Mapping, Union +from typing import Any, Collection, Mapping, Optional, Union import tabulate @@ -49,7 +49,7 @@ def list_cars(cfg: types.Config): console.println(tabulate.tabulate([[c.name, c.type, c.description] for c in cars], headers=["Name", "Type", "Description"])) -def load_car(repo: str, name: Collection[str], car_params: Mapping = None) -> "Car": +def load_car(repo: str, name: Collection[str], car_params: Optional[Mapping] = None) -> "Car": class Component: def __init__(self, root_path, entry_point): self.root_path = root_path @@ -165,7 +165,7 @@ def load_car(self, name, car_params=None): config 
= self._config_loader(car_config_file) root_paths = [] config_paths = [] - config_base_vars = {} + config_base_vars: Mapping[str, Any] = {} description = self._value(config, ["meta", "description"], default="") car_type = self._value(config, ["meta", "type"], default="car") config_bases = self._value(config, ["config", "base"], default="").split(",") @@ -192,7 +192,7 @@ def load_car(self, name, car_params=None): def _config_loader(self, file_name): config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is - config.optionxform = lambda option: option + config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign] config.read(file_name) return config @@ -239,7 +239,7 @@ def __init__( names: Collection[str], root_path: Union[None, str, Collection[str]], config_paths: Collection[str], - variables: Mapping[str, Any] = None, + variables: Optional[Mapping[str, Any]] = None, ): """ Creates new settings for a benchmark candidate. 
@@ -252,12 +252,12 @@ def __init__( if variables is None: variables = {} if isinstance(names, str): - self.names = [names] + self.names: Collection[str] = [names] else: self.names = names if root_path is None: - self.root_path = [] + self.root_path: Collection[str] = [] elif isinstance(root_path, str): self.root_path = [root_path] else: @@ -393,7 +393,7 @@ def load_plugin(self, name, config_names, plugin_params=None): config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is - config.optionxform = lambda option: option + config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign] config.read(config_file) if "config" in config and "base" in config["config"]: config_bases = config["config"]["base"].split(",") diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 8a228d1bc..3c1f072bf 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -import importlib.machinery +import importlib.util import logging import os import sys @@ -45,7 +45,7 @@ def __init__(self, root_path: Union[str, Collection[str]], component_entry_point ``root_path``. :param recurse: Search recursively for modules but ignore modules starting with "_" (Default: ``True``). 
""" - self.root_path: Collection[str] = root_path if isinstance(root_path, list) else [root_path] + self.root_path: Collection[str] = root_path if isinstance(root_path, list) else [str(root_path)] self.component_entry_point = component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) @@ -67,7 +67,7 @@ def _load_component(self, component_name: str, module_dirs: Collection[str], roo for name, p in self._modules(module_dirs, component_name, root_path): self.logger.debug("Loading module [%s]: %s", name, p) spec = importlib.util.spec_from_file_location(name, p) - if spec is None: + if spec is None or spec.loader is None: raise exceptions.SystemSetupError(f"Could not load module [{name}]") m = importlib.util.module_from_spec(spec) spec.loader.exec_module(m) diff --git a/pyproject.toml b/pyproject.toml index 1331874bd..627e9ecd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -143,12 +143,25 @@ junit_logging = "all" asyncio_mode = "strict" xfail_strict = true -# With rare exceptions, Rally does not use type hints. The intention of the -# following largely reduced mypy configuration scope is verification of argument -# types in config.Config methods while introducing configuration properties -# (props). The error we are after here is "arg-type". +# With exceptions specified in mypy override section, Rally does not use type +# hints (they were a novelty when Rally came to be). Hints are being slowly and +# opportunistically introduced whenever we revisit a group of modules. +# +# The intention of the following largely reduced global config scope is +# verification of argument types in config.Config methods while introducing +# configuration properties (props). The intention of "disable_error_code" option +# is to keep "arg-type" error code, while disabling other error codes. 
+# Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] -python_version = 3.8 +python_version = "3.8" +# subset of "strict", kept at global config level as some of the options are +# supported only at this level +# https://mypy.readthedocs.io/en/stable/existing_code.html#introduce-stricter-options +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true +strict_equality = true +extra_checks = true check_untyped_defs = true disable_error_code = [ "assignment", @@ -168,3 +181,33 @@ disable_error_code = [ "union-attr", "var-annotated", ] +files = [ + "esrally/", + "it/", + "tests/", +] + +[[tool.mypy.overrides]] +module = [ + "esrally.mechanic.team", + "esrally.utils.modules", +] +# this should be a copy of disabled_error_code from above +enable_error_code = [ + "assignment", + "attr-defined", + "call-arg", + "call-overload", + "dict-item", + "import-not-found", + "import-untyped", + "index", + "list-item", + "misc", + "name-defined", + "operator", + "str-bytes-safe", + "syntax", + "union-attr", + "var-annotated", +] diff --git a/tests/client/factory_test.py b/tests/client/factory_test.py index 791423365..fd966d30f 100644 --- a/tests/client/factory_test.py +++ b/tests/client/factory_test.py @@ -29,7 +29,7 @@ import pytest import trustme import urllib3.exceptions -from elastic_transport import ApiResponseMeta +from elastic_transport import ApiResponseMeta, HttpHeaders, NodeConfig from pytest_httpserver import HTTPServer from esrally import client, doc_link, exceptions @@ -38,7 +38,17 @@ def _api_error(status, message): - return elasticsearch.ApiError(message, ApiResponseMeta(status=status, http_version="1.1", headers={}, duration=0.0, node=None), None) + return elasticsearch.ApiError( + message, + ApiResponseMeta( + status=status, + http_version="1.1", + headers=HttpHeaders(), + duration=0.0, + node=NodeConfig(scheme="https", host="localhost", port=9200), + ), + None, + ) class TestEsClientFactory: @@ -518,7 +528,7 @@ def 
test_connection_ssl_error(self, es): def test_connection_protocol_error(self, es): es.cluster.health.side_effect = elasticsearch.ConnectionError( message="N/A", - errors=[urllib3.exceptions.ProtocolError("Connection aborted.")], + errors=[urllib3.exceptions.ProtocolError("Connection aborted.")], # type: ignore[arg-type] ) with pytest.raises( exceptions.SystemSetupError, diff --git a/tests/driver/driver_test.py b/tests/driver/driver_test.py index e9360d849..10f753793 100644 --- a/tests/driver/driver_test.py +++ b/tests/driver/driver_test.py @@ -1894,7 +1894,13 @@ async def test_execute_single_with_connection_error_always_aborts(self, on_error async def test_execute_single_with_http_400_aborts_when_specified(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=404, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock( side_effect=elasticsearch.NotFoundError(message="not found", meta=error_meta, body="the requested document could not be found") ) @@ -1912,7 +1918,13 @@ async def test_execute_single_with_http_400_with_empty_raw_response_body(self): params = None empty_body = io.BytesIO(b"") str_literal_empty_body = str(empty_body) - error_meta = elastic_transport.ApiResponseMeta(status=413, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=413, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.ApiError(message=str_literal_empty_body, meta=error_meta, body=empty_body)) with pytest.raises(exceptions.RallyAssertionError) as exc: @@ -1925,7 +1937,13 @@ async 
def test_execute_single_with_http_400_with_raw_response_body(self): params = None body = io.BytesIO(b"Huge error") str_literal = str(body) - error_meta = elastic_transport.ApiResponseMeta(status=499, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=499, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.ApiError(message=str_literal, meta=error_meta, body=body)) with pytest.raises(exceptions.RallyAssertionError) as exc: @@ -1936,7 +1954,13 @@ async def test_execute_single_with_http_400_with_raw_response_body(self): async def test_execute_single_with_http_400(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=404, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock( side_effect=elasticsearch.NotFoundError(message="not found", meta=error_meta, body="the requested document could not be found") ) @@ -1956,7 +1980,13 @@ async def test_execute_single_with_http_400(self): async def test_execute_single_with_http_413(self): es = None params = None - error_meta = elastic_transport.ApiResponseMeta(status=413, http_version="1.1", headers={}, duration=0.0, node=None) + error_meta = elastic_transport.ApiResponseMeta( + status=413, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0.0, + node=elastic_transport.NodeConfig(scheme="http", host="localhost", port=9200), + ) runner = mock.AsyncMock(side_effect=elasticsearch.NotFoundError(message="", meta=error_meta, body="")) ops, unit, request_meta_data = await 
driver.execute_single(self.context_managed(runner), es, params, on_error="continue") diff --git a/tests/driver/runner_test.py b/tests/driver/runner_test.py index 910c1de95..561404752 100644 --- a/tests/driver/runner_test.py +++ b/tests/driver/runner_test.py @@ -3905,8 +3905,14 @@ async def test_create_ml_datafeed(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_create_ml_datafeed_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.put_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.put_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() datafeed_id = "some-data-feed" body = {"job_id": "total-requests", "indices": ["server-metrics"]} @@ -3935,8 +3941,16 @@ async def test_delete_ml_datafeed(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_delete_ml_datafeed_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.delete_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.delete_datafeed = mock.AsyncMock( + side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request") + ) 
es.perform_request = mock.AsyncMock() datafeed_id = "some-data-feed" @@ -3969,8 +3983,14 @@ async def test_start_ml_datafeed_with_body(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_start_ml_datafeed_with_body_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.start_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.start_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() body = {"end": "now"} params = {"datafeed-id": "some-data-feed", "body": body} @@ -4018,8 +4038,14 @@ async def test_stop_ml_datafeed(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_stop_ml_datafeed_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.stop_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.stop_datafeed = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() params = { @@ -4070,8 +4096,14 @@ async def test_create_ml_job(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def 
test_create_ml_job_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.put_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.put_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() body = { @@ -4113,8 +4145,14 @@ async def test_delete_ml_job(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_delete_ml_job_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.delete_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.delete_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() job_id = "an-ml-job" @@ -4145,8 +4183,14 @@ async def test_open_ml_job(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_open_ml_job_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.open_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + 
status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.open_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() job_id = "an-ml-job" @@ -4177,8 +4221,14 @@ async def test_close_ml_job(self, es): @mock.patch("elasticsearch.Elasticsearch") @pytest.mark.asyncio async def test_close_ml_job_fallback(self, es): - error_meta = elastic_transport.ApiResponseMeta(status=400, http_version="1.1", headers=None, duration=0, node=None) - es.ml.close_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message=400, meta=error_meta, body="Bad Request")) + error_meta = elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ) + es.ml.close_job = mock.AsyncMock(side_effect=elasticsearch.BadRequestError(message="400", meta=error_meta, body="Bad Request")) es.perform_request = mock.AsyncMock() params = { @@ -7399,7 +7449,7 @@ async def test_is_transparent_on_success_when_no_retries(self): @pytest.mark.asyncio async def test_is_transparent_on_exception_when_no_retries(self): - delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host")) + delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionError(message="no route to host")) es = None params = { # no retries @@ -7537,7 +7587,7 @@ async def test_retries_mixed_timeout_and_application_errors(self): @pytest.mark.asyncio async def test_does_not_retry_on_timeout_if_not_wanted(self): - delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionTimeout(408, "timed out")) + delegate = mock.AsyncMock(side_effect=elasticsearch.ConnectionTimeout(message="timed out")) es = None params = {"retries": 3, 
"retry-wait-period": 0.01, "retry-on-timeout": False, "retry-on-error": True} retrier = runner.Retry(delegate) diff --git a/tests/mechanic/launcher_test.py b/tests/mechanic/launcher_test.py index c6758880e..9bdb84c6c 100644 --- a/tests/mechanic/launcher_test.py +++ b/tests/mechanic/launcher_test.py @@ -72,7 +72,7 @@ def __init__(self, client_options): def info(self): if self.client_options.get("raise-error-on-info", False): - raise elasticsearch.TransportError(401, "Unauthorized") + raise elasticsearch.TransportError(message="Unauthorized") return self._info def search(self, *args, **kwargs): diff --git a/tests/metrics_test.py b/tests/metrics_test.py index b299172c9..adf9b1f18 100644 --- a/tests/metrics_test.py +++ b/tests/metrics_test.py @@ -234,7 +234,12 @@ def logging_statements(self, retries): return logging_statements def raise_error(self): - err = elasticsearch.exceptions.ApiError("unit-test", meta=TestEsClient.ApiResponseMeta(status=self.status_code), body={}) + err = elasticsearch.exceptions.ApiError( + "unit-test", + # TODO remove this ignore when introducing type hints + meta=TestEsClient.ApiResponseMeta(status=self.status_code), # type: ignore[arg-type] + body={}, + ) raise err class BulkIndexError: @@ -321,7 +326,9 @@ def raise_error(self): def test_raises_sytem_setup_error_on_authentication_problems(self): def raise_authentication_error(): - raise elasticsearch.exceptions.AuthenticationException(meta=None, body=None, message="unit-test") + raise elasticsearch.exceptions.AuthenticationException( + meta=None, body=None, message="unit-test" # type: ignore[arg-type] # TODO remove this ignore when introducing type hints + ) client = metrics.EsClient(self.ClientMock([{"host": "127.0.0.1", "port": "9243"}])) @@ -334,7 +341,9 @@ def raise_authentication_error(): def test_raises_sytem_setup_error_on_authorization_problems(self): def raise_authorization_error(): - raise elasticsearch.exceptions.AuthorizationException(meta=None, body=None, message="unit-test") + 
raise elasticsearch.exceptions.AuthorizationException( + meta=None, body=None, message="unit-test" # type: ignore[arg-type] # TODO remove this ignore when introducing type hints + ) client = metrics.EsClient(self.ClientMock([{"host": "127.0.0.1", "port": "9243"}])) diff --git a/tests/telemetry_test.py b/tests/telemetry_test.py index c14b79c26..001a42e7c 100644 --- a/tests/telemetry_test.py +++ b/tests/telemetry_test.py @@ -24,6 +24,7 @@ from unittest import mock from unittest.mock import call +import elastic_transport import elasticsearch import pytest @@ -275,7 +276,8 @@ class ApiResponseMeta: def __call__(self, status=None, body=None, message=None): return elasticsearch.ApiError( - meta=self.ApiResponseMeta(status=status), + # TODO remove this ignore when introducing type hints + meta=self.ApiResponseMeta(status=status), # type: ignore[arg-type] body=body, message=message, ) @@ -1751,7 +1753,18 @@ def test_no_metrics_if_no_searchable_snapshots_stats(self, metrics_store_put_doc metrics_store = metrics.EsMetricsStore(cfg) client = Client( transport_client=TransportClient( - force_error=True, error=elasticsearch.NotFoundError("", "", {"error": {"reason": "No searchable snapshots indices found"}}) + force_error=True, + error=elasticsearch.NotFoundError( + message="", + meta=elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ), + body={"error": {"reason": "No searchable snapshots indices found"}}, + ), ) ) recorder = telemetry.SearchableSnapshotsStatsRecorder( @@ -4885,7 +4898,17 @@ def test_uses_indices_param_if_specified_instead_of_data_stream_names(self, es): def test_error_on_retrieval_does_not_store_metrics(self, es, metrics_store_cluster_level, caplog): cfg = create_config() metrics_store = metrics.EsMetricsStore(cfg) - es.indices.disk_usage.side_effect = elasticsearch.RequestError(message="error", 
meta=None, body=None) + es.indices.disk_usage.side_effect = elasticsearch.RequestError( + message="error", + meta=elastic_transport.ApiResponseMeta( + status=400, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ), + body=None, + ) device = telemetry.DiskUsageStats({}, es, metrics_store, index_names=["foo"], data_stream_names=[]) t = telemetry.Telemetry(enabled_devices=[device.command], devices=[device]) t.on_benchmark_start() @@ -4916,7 +4939,17 @@ def test_no_indices_fails(self, es, metrics_store_cluster_level, caplog): def test_missing_all_fails(self, es, metrics_store_cluster_level, caplog): cfg = create_config() metrics_store = metrics.EsMetricsStore(cfg) - es.indices.disk_usage.side_effect = elasticsearch.NotFoundError(message="error", meta=None, body=None) + es.indices.disk_usage.side_effect = elasticsearch.NotFoundError( + message="error", + meta=elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ), + body=None, + ) device = telemetry.DiskUsageStats({}, es, metrics_store, index_names=["foo", "bar"], data_stream_names=[]) t = telemetry.Telemetry(enabled_devices=[device.command], devices=[device]) t.on_benchmark_start() @@ -4933,7 +4966,17 @@ def test_missing_all_fails(self, es, metrics_store_cluster_level, caplog): def test_some_mising_succeeds(self, es, metrics_store_cluster_level, caplog): cfg = create_config() metrics_store = metrics.EsMetricsStore(cfg) - not_found_response = elasticsearch.NotFoundError(message="error", meta=None, body=None) + not_found_response = elasticsearch.NotFoundError( + message="error", + meta=elastic_transport.ApiResponseMeta( + status=404, + http_version="1.1", + headers=elastic_transport.HttpHeaders(), + duration=0, + 
node=elastic_transport.NodeConfig(scheme="https", host="localhost", port=9200), + ), + body=None, + ) successful_response = { "_shards": {"failed": 0}, "foo": { diff --git a/tests/utils/collections_test.py b/tests/utils/collections_test.py index e262b52f8..0c451bec6 100644 --- a/tests/utils/collections_test.py +++ b/tests/utils/collections_test.py @@ -18,7 +18,7 @@ import random from typing import Any, Mapping -import pytest # type: ignore +import pytest from esrally.utils import collections diff --git a/tests/utils/net_test.py b/tests/utils/net_test.py index 4fa912c9d..733acbbbd 100644 --- a/tests/utils/net_test.py +++ b/tests/utils/net_test.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import random +from typing import Mapping, Union from unittest import mock import pytest @@ -135,7 +136,8 @@ def raise_error(seconds): def test_download_http_retry_incomplete_read_retry_failure(httpserver, tmp_path): data = b"x" * 10 - short_resp = Response(headers={"Content-Length": 100, "foo": "bar"}) + headers: Mapping[str, Union[str, int]] = {"Content-Length": 100, "foo": "bar"} + short_resp = Response(headers=headers) short_resp.automatically_set_content_length = False short_resp.set_data(data) @@ -154,7 +156,8 @@ def sleep(seconds): def test_download_http_retry_incomplete_read_retry_success(httpserver, tmp_path): data = b"x" * 10 - short_resp = Response(headers={"Content-Length": 100, "foo": "bar"}) + headers: Mapping[str, Union[str, int]] = {"Content-Length": 100, "foo": "bar"} + short_resp = Response(headers=headers) short_resp.automatically_set_content_length = False short_resp.set_data(data) diff --git a/tests/utils/versions_test.py b/tests/utils/versions_test.py index 845ceb311..f03cb0329 100644 --- a/tests/utils/versions_test.py +++ b/tests/utils/versions_test.py @@ -18,7 +18,7 @@ import random import re -import pytest # type: ignore +import pytest from esrally import exceptions from esrally.utils import versions 
From 357726507540adbda144a83d7ecc58c25d4d092d Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 1 Jul 2024 10:36:49 -0700 Subject: [PATCH 09/26] Add local .venv interpreter to mypy configuration for ease of local configuration --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 627e9ecd3..3827d4de8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,6 +153,7 @@ xfail_strict = true # is to keep "arg-type" error code, while disabling other error codes. # Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] +python_executable = ".venv/bin/python" python_version = "3.8" # subset of "strict", kept at global config level as some of the options are # supported only at this level From 60b1d5fdd4ddaf66dc02f7d55777a3217ea342db Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 1 Jul 2024 10:40:20 -0700 Subject: [PATCH 10/26] Add comment explaining `spec_from_file_location` and friends --- esrally/utils/modules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 3c1f072bf..a7b0bc83a 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -66,6 +66,8 @@ def _load_component(self, component_name: str, module_dirs: Collection[str], roo root_module_name = "%s.%s" % (component_name, self.component_entry_point) for name, p in self._modules(module_dirs, component_name, root_path): self.logger.debug("Loading module [%s]: %s", name, p) + # Use the util methods instead of `importlib.import_module` to allow for more fine-grained control over the import process. + # In particular, we want to be able to import multiple modules that use the same name, but are from different directories. 
spec = importlib.util.spec_from_file_location(name, p) if spec is None or spec.loader is None: raise exceptions.SystemSetupError(f"Could not load module [{name}]") From 8a21d640cd5ced21ac67db0487e3a7c5c6676a9e Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 1 Jul 2024 10:46:09 -0700 Subject: [PATCH 11/26] Remove executable for CI. Dang --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3827d4de8..627e9ecd3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -153,7 +153,6 @@ xfail_strict = true # is to keep "arg-type" error code, while disabling other error codes. # Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] -python_executable = ".venv/bin/python" python_version = "3.8" # subset of "strict", kept at global config level as some of the options are # supported only at this level From 2cafe2986b041b1e8a775d03235c690097e3d21d Mon Sep 17 00:00:00 2001 From: Grzegorz Banasiak Date: Tue, 2 Jul 2024 08:33:58 +0200 Subject: [PATCH 12/26] Pin types-tabulate --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43cbdfeee..8649788f6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,7 @@ repos: additional_dependencies: [ "elasticsearch[async]==8.6.1", "elastic-transport==8.4.1", - "types-tabulate", + "types-tabulate==0.8.9", ] args: [ "--ignore-missing-imports", From 3e03eeab0025f712f0923e3d88743457b5d4266c Mon Sep 17 00:00:00 2001 From: favilo Date: Tue, 2 Jul 2024 09:15:02 -0700 Subject: [PATCH 13/26] I went a little buckwild while fixing type annotations. 
I turned on a stricter mode than what we have configured in pyproject.toml --- esrally/mechanic/launcher.py | 2 +- esrally/mechanic/team.py | 159 +++++++++++++++++---------- esrally/utils/io.py | 204 +++++++++++++++++++++-------------- esrally/utils/modules.py | 6 +- esrally/utils/process.py | 19 ++-- pyproject.toml | 20 ++-- 6 files changed, 249 insertions(+), 161 deletions(-) diff --git a/esrally/mechanic/launcher.py b/esrally/mechanic/launcher.py index 18d21f90b..b23a0682f 100644 --- a/esrally/mechanic/launcher.py +++ b/esrally/mechanic/launcher.py @@ -242,7 +242,7 @@ def stop(self, nodes, metrics_store): stop_watch.start() try: es.terminate() - es.wait(10.0) + es.wait(10) stopped_nodes.append(node) except psutil.NoSuchProcess: self.logger.warning("No process found with PID [%s] for node [%s].", es.pid, node_name) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 6023e4698..9b1fa678c 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -20,7 +20,18 @@ import os from enum import Enum from types import ModuleType -from typing import Any, Collection, Mapping, Optional, Union +from typing import ( + Any, + Callable, + Collection, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Tuple, + Union, +) import tabulate @@ -30,14 +41,14 @@ TEAM_FORMAT_VERSION = 1 -def _path_for(team_root_path, team_member_type): +def _path_for(team_root_path: str, team_member_type: str) -> str: root_path = os.path.join(team_root_path, team_member_type, f"v{TEAM_FORMAT_VERSION}") if not os.path.exists(root_path): raise exceptions.SystemSetupError(f"Path {root_path} for {team_member_type} does not exist.") return root_path -def list_cars(cfg: types.Config): +def list_cars(cfg: types.Config) -> None: loader = CarLoader(team_path(cfg)) cars = [] for name in loader.car_names(): @@ -51,15 +62,15 @@ def list_cars(cfg: types.Config): def load_car(repo: str, name: Collection[str], car_params: Optional[Mapping] = None) -> "Car": class Component: - 
def __init__(self, root_path, entry_point): + def __init__(self, root_path: str, entry_point: str): self.root_path = root_path self.entry_point = entry_point root_paths = [] # preserve order as we append to existing config files later during provisioning. all_config_paths = [] - all_config_base_vars = {} - all_car_vars = {} + all_config_base_vars: MutableMapping[str, str] = {} + all_car_vars: MutableMapping[str, str] = {} for n in name: descriptor = CarLoader(repo).load_car(n, car_params) @@ -76,7 +87,7 @@ def __init__(self, root_path, entry_point): if len(all_config_paths) == 0: raise exceptions.SystemSetupError(f"At least one config base is required for car {name}") - variables = {} + variables: MutableMapping[str, str] = {} # car variables *always* take precedence over config base variables variables.update(all_config_base_vars) variables.update(all_car_vars) @@ -84,7 +95,7 @@ def __init__(self, root_path, entry_point): return Car(name, root_paths, all_config_paths, variables) -def list_plugins(cfg: types.Config): +def list_plugins(cfg: types.Config) -> None: plugins = PluginLoader(team_path(cfg)).plugins() if plugins: console.println("Available Elasticsearch plugins:\n") @@ -93,12 +104,16 @@ def list_plugins(cfg: types.Config): console.println("No Elasticsearch plugins are available.\n") -def load_plugin(repo, name, config_names, plugin_params=None): +def load_plugin( + repo: str, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None +) -> "PluginDescriptor": return PluginLoader(repo).load_plugin(name, config_names, plugin_params) -def load_plugins(repo, plugin_names, plugin_params=None): - def name_and_config(p): +def load_plugins( + repo: str, plugin_names: Collection[str], plugin_params: Optional[Mapping[str, str]] = None +) -> Collection["PluginDescriptor"]: + def name_and_config(p: str) -> Tuple[str, Optional[Collection[str]]]: plugin_spec = p.split(":") if len(plugin_spec) == 1: return plugin_spec[0], None 
@@ -115,7 +130,7 @@ def name_and_config(p): return plugins -def team_path(cfg: types.Config): +def team_path(cfg: types.Config) -> str: root_path = cfg.opts("mechanic", "team.path", mandatory=False) if root_path: return root_path @@ -140,35 +155,38 @@ def team_path(cfg: types.Config): class CarLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): self.cars_dir = _path_for(team_root_path, "cars") self.logger = logging.getLogger(__name__) - def car_names(self): - def __car_name(path): + def car_names(self) -> Iterator[str]: + def __car_name(path: str) -> str: p, _ = io.splitext(path) return io.basename(p) - def __is_car(path): + def __is_car(path: str) -> bool: _, extension = io.splitext(path) return extension == ".ini" return map(__car_name, filter(__is_car, os.listdir(self.cars_dir))) - def _car_file(self, name): + def _car_file(self, name: str) -> str: return os.path.join(self.cars_dir, f"{name}.ini") - def load_car(self, name, car_params=None): + def load_car(self, name: str, car_params: Optional[Mapping[str, Any]] = None) -> "CarDescriptor": car_config_file = self._car_file(name) if not io.exists(car_config_file): raise exceptions.SystemSetupError(f"Unknown car [{name}]. List the available cars with {PROGRAM_NAME} list cars.") config = self._config_loader(car_config_file) - root_paths = [] - config_paths = [] + root_paths: List[str] = [] + config_paths: List[str] = [] config_base_vars: Mapping[str, Any] = {} description = self._value(config, ["meta", "description"], default="") car_type = self._value(config, ["meta", "type"], default="car") - config_bases = self._value(config, ["config", "base"], default="").split(",") + config_base = self._value(config, ["config", "base"], default="") + assert config_base is not None, f"Car [{name}] does not define a config base." + assert isinstance(config_base, str), f"Car [{name}] defines an invalid config base [{config_base}]." 
+ config_bases = config_base.split(",") for base in config_bases: if base: root_path = os.path.join(self.cars_dir, base) @@ -189,24 +207,27 @@ def load_car(self, name, car_params=None): return CarDescriptor(name, description, car_type, root_paths, config_paths, config_base_vars, variables) - def _config_loader(self, file_name): + def _config_loader(self, file_name: str) -> "configparser.ConfigParser": config = configparser.ConfigParser(interpolation=configparser.ExtendedInterpolation()) # Do not modify the case of option keys but read them as is config.optionxform = lambda optionstr: optionstr # type: ignore[method-assign] config.read(file_name) return config - def _value(self, cfg, section_path, default=None): - path = [section_path] if (isinstance(section_path, str)) else section_path + def _value( + self, cfg: "configparser.ConfigParser", section_path: Union[str, Collection[str]], default: Optional[str] = None + ) -> Optional[Mapping[str, Any]]: + path: Collection[str] = [section_path] if (isinstance(section_path, str)) else section_path current_cfg = cfg for k in path: + assert isinstance(current_cfg, dict), f"Expected a dict but got [{current_cfg}] instead." 
if k in current_cfg: current_cfg = current_cfg[k] else: return default return current_cfg - def _copy_section(self, cfg, section, target): + def _copy_section(self, cfg: "configparser.ConfigParser", section: str, target: MutableMapping[str, Any]) -> MutableMapping[str, Any]: if section in cfg.sections(): for k, v in cfg[section].items(): target[k] = v @@ -214,7 +235,16 @@ def _copy_section(self, cfg, section, target): class CarDescriptor: - def __init__(self, name, description, type, root_paths, config_paths, config_base_variables, variables): + def __init__( + self, + name: str, + description: str, + type: str, + root_paths: Collection[str], + config_paths: Collection[str], + config_base_variables: Mapping[str, str], + variables: Mapping[str, str], + ): self.name = name self.description = description self.type = type @@ -223,10 +253,10 @@ def __init__(self, name, description, type, root_paths, config_paths, config_bas self.config_base_variables = config_base_variables self.variables = variables - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and self.name == other.name @@ -265,40 +295,40 @@ def __init__( self.config_paths = config_paths self.variables = variables - def mandatory_var(self, name): + def mandatory_var(self, name: str) -> str: try: return self.variables[name] except KeyError: raise exceptions.SystemSetupError(f'Car "{self.name}" requires config key "{name}"') @property - def name(self): + def name(self) -> str: return "+".join(self.names) # Adapter method for BootstrapHookHandler @property - def config(self): + def config(self) -> str: return self.name @property - def safe_name(self): + def safe_name(self) -> str: return "_".join(self.names) - def __str__(self): + def __str__(self) -> str: return self.name class PluginLoader: - def __init__(self, team_root_path): + def __init__(self, team_root_path: str): 
self.plugins_root_path = _path_for(team_root_path, "plugins") self.logger = logging.getLogger(__name__) - def plugins(self, variables=None): + def plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: known_plugins = self._core_plugins(variables) + self._configured_plugins(variables) sorted(known_plugins, key=lambda p: p.name) return known_plugins - def _core_plugins(self, variables=None): + def _core_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: core_plugins = [] core_plugins_path = os.path.join(self.plugins_root_path, "core-plugins.txt") if os.path.exists(core_plugins_path): @@ -310,7 +340,7 @@ def _core_plugins(self, variables=None): core_plugins.append(PluginDescriptor(name=values[0], core_plugin=True, variables=variables)) return core_plugins - def _configured_plugins(self, variables=None): + def _configured_plugins(self, variables: Optional[Mapping[str, str]] = None) -> List["PluginDescriptor"]: configured_plugins = [] # each directory is a plugin, each .ini is a config (just go one level deep) for entry in os.listdir(self.plugins_root_path): @@ -324,10 +354,10 @@ def _configured_plugins(self, variables=None): configured_plugins.append(PluginDescriptor(name=plugin_name, config=config, variables=variables)) return configured_plugins - def _plugin_file(self, name, config): + def _plugin_file(self, name: str, config: str) -> str: return os.path.join(self._plugin_root_path(name), "%s.ini" % config) - def _plugin_root_path(self, name): + def _plugin_root_path(self, name: str) -> str: return os.path.join(self.plugins_root_path, self._plugin_name_to_file(name)) # As we allow to store Python files in the plugin directory and the plugin directory also serves as the root path of the corresponding @@ -335,16 +365,18 @@ def _plugin_root_path(self, name): # need to switch from underscores to hyphens and vice versa. 
# # We are implicitly assuming that plugin names stick to the convention of hyphen separation to simplify implementation and usage a bit. - def _file_to_plugin_name(self, file_name): + def _file_to_plugin_name(self, file_name: str) -> str: return file_name.replace("_", "-") - def _plugin_name_to_file(self, plugin_name): + def _plugin_name_to_file(self, plugin_name: str) -> str: return plugin_name.replace("-", "_") - def _core_plugin(self, name, variables=None): + def _core_plugin(self, name: str, variables: Optional[Mapping[str, str]] = None) -> Optional["PluginDescriptor"]: return next((p for p in self._core_plugins(variables) if p.name == name and p.config is None), None) - def load_plugin(self, name, config_names, plugin_params=None): + def load_plugin( + self, name: str, config_names: Optional[Collection[str]], plugin_params: Optional[Mapping[str, str]] = None + ) -> "PluginDescriptor": if config_names is not None: self.logger.info("Loading plugin [%s] with configuration(s) [%s].", name, config_names) else: @@ -426,7 +458,15 @@ class PluginDescriptor: # name of the initial Python file to load for plugins. 
entry_point = "plugin" - def __init__(self, name, core_plugin=False, config=None, root_path=None, config_paths=None, variables=None): + def __init__( + self, + name: str, + core_plugin: bool = False, + config: Optional[Collection[str]] = None, + root_path: Optional[str] = None, + config_paths: Optional[Collection[str]] = None, + variables: Optional[Mapping[str, Any]] = None, + ): if config_paths is None: config_paths = [] if variables is None: @@ -438,27 +478,27 @@ def __init__(self, name, core_plugin=False, config=None, root_path=None, config_ self.config_paths = config_paths self.variables = variables - def __str__(self): - return "Plugin descriptor for [%s]" % self.name + def __str__(self) -> str: + return f"Plugin descriptor for [{self.name}]" - def __repr__(self): + def __repr__(self) -> str: r = [] for prop, value in vars(self).items(): r.append("%s = [%s]" % (prop, repr(value))) return ", ".join(r) @property - def moved_to_module(self): + def moved_to_module(self) -> bool: # For a BWC escape hatch we first check if the plugin is listed in rally-teams' "core-plugin.txt", # thus allowing users to override the teams path or revision to include the repository-s3/azure/gcs plugins in # "core-plugin.txt" # TODO: https://github.com/elastic/rally/issues/1622 return self.name in ["repository-s3", "repository-gcs", "repository-azure"] and not self.core_plugin - def __hash__(self): + def __hash__(self) -> int: return hash(self.name) ^ hash(self.config) ^ hash(self.core_plugin) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and (self.name, self.config, self.core_plugin) == (other.name, other.config, other.core_plugin) @@ -466,14 +506,14 @@ class BootstrapPhase(Enum): post_install = 10 @classmethod - def valid(cls, name): + def valid(cls, name: str) -> bool: for n in BootstrapPhase.names(): if n == name: return True return False @classmethod - def names(cls): + def names(cls) -> Collection[str]: return [p.name for 
p in list(BootstrapPhase)] @@ -482,7 +522,7 @@ class BootstrapHookHandler: Responsible for loading and executing component-specific intitialization code. """ - def __init__(self, component, loader_class=modules.ComponentLoader): + def __init__(self, component: Any, loader_class: Callable = modules.ComponentLoader): """ Creates a new BootstrapHookHandler. @@ -497,13 +537,13 @@ def __init__(self, component, loader_class=modules.ComponentLoader): else: root_path = [self.component.root_path] self.loader = loader_class(root_path=root_path, component_entry_point=self.component.entry_point, recurse=False) - self.hooks = {} + self.hooks: MutableMapping[str, List[Callable]] = {} self.logger = logging.getLogger(__name__) - def can_load(self): + def can_load(self) -> bool: return self.loader.can_load() - def load(self): + def load(self) -> None: root_modules: Collection[ModuleType] = self.loader.load() try: # every module needs to have a register() method @@ -517,15 +557,16 @@ def load(self): self.logger.exception(msg) raise exceptions.SystemSetupError(msg) - def register(self, phase, hook): + def register(self, phase: str, hook: Callable) -> None: self.logger.info("Registering bootstrap hook [%s] for phase [%s] in component [%s]", hook.__name__, phase, self.component.name) if not BootstrapPhase.valid(phase): raise exceptions.SystemSetupError(f"Unknown bootstrap phase [{phase}]. 
Valid phases are: {BootstrapPhase.names()}.") if phase not in self.hooks: - self.hooks[phase] = [] + empty: List[Callable] = [] + self.hooks[phase] = empty self.hooks[phase].append(hook) - def invoke(self, phase, **kwargs): + def invoke(self, phase: str, **kwargs: Mapping[str, Any]) -> None: if phase in self.hooks: self.logger.info("Invoking phase [%s] for component [%s] in config [%s]", phase, self.component.name, self.component.config) for hook in self.hooks[phase]: diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 6e4ff4f74..47bcafdda 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -26,7 +26,22 @@ import subprocess import tarfile import zipfile -from typing import AnyStr +from types import TracebackType +from typing import ( + IO, + Any, + AnyStr, + Callable, + Collection, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) import zstandard @@ -40,27 +55,31 @@ class FileSource: FileSource is a wrapper around a plain file which simplifies testing of file I/O calls. 
""" - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None + self.f: Optional[IO[Any]] = None - def open(self): + def open(self) -> "FileSource": self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.f is not None, "File is not open" self.f.seek(offset) - def read(self): + def read(self) -> bytes: + assert self.f is not None, "File is not open" return self.f.read() - def readline(self): + def readline(self) -> bytes: + assert self.f is not None, "File is not open" return self.f.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[bytes]: + assert self.f is not None, "File is not open" lines = [] f = self.f for _ in range(num_lines): @@ -70,19 +89,22 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> "FileSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -91,14 +113,14 @@ class MmapSource: MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. 
""" - def __init__(self, file_name, mode, encoding="utf-8"): + def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f = None - self.mm = None + self.f: Optional[IO[Any]] = None + self.mm: Optional[mmap.mmap] = None - def open(self): + def open(self) -> "MmapSource": self.f = open(self.file_name, mode="r+b") self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) self.mm.madvise(mmap.MADV_SEQUENTIAL) @@ -106,16 +128,20 @@ def open(self): # allow for chaining return self - def seek(self, offset): + def seek(self, offset: int) -> None: + assert self.mm is not None, "Source is not open" self.mm.seek(offset) - def read(self): + def read(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.read() - def readline(self): + def readline(self) -> bytes: + assert self.mm is not None, "Source is not open" return self.mm.readline() - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[bytes]: + assert self.mm is not None, "Source is not open" lines = [] mm = self.mm for _ in range(num_lines): @@ -125,21 +151,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: + assert self.mm is not None, "Source is not open" self.mm.close() self.mm = None + assert self.f is not None, "File is not open" self.f.close() self.f = None - def __enter__(self): + def __enter__(self) -> "MmapSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return self.file_name @@ -150,10 +180,10 @@ class DictStringFileSourceFactory: It is intended for scenarios where 
multiple files may be read by client code. """ - def __init__(self, name_to_contents): + def __init__(self, name_to_contents: Mapping[str, Sequence[str]]): self.name_to_contents = name_to_contents - def __call__(self, name, mode, encoding="utf-8"): + def __call__(self, name: str, mode: str, encoding: str = "utf-8") -> "StringAsFileSource": return StringAsFileSource(self.name_to_contents[name], mode, encoding) @@ -163,7 +193,7 @@ class StringAsFileSource: be used in production code. """ - def __init__(self, contents, mode, encoding="utf-8"): + def __init__(self, contents: Sequence[str], mode: str, encoding: str = "utf-8"): """ :param contents: The file contents as an array of strings. Each item in the array should correspond to one line. :param mode: The file mode. It is ignored in this implementation but kept to implement the same interface as ``FileSource``. @@ -173,20 +203,20 @@ def __init__(self, contents, mode, encoding="utf-8"): self.current_index = 0 self.opened = False - def open(self): + def open(self) -> "StringAsFileSource": self.opened = True return self - def seek(self, offset): + def seek(self, offset: int) -> None: self._assert_opened() if offset != 0: raise AssertionError("StringAsFileSource does not support random seeks") - def read(self): + def read(self) -> str: self._assert_opened() return "\n".join(self.contents) - def readline(self): + def readline(self) -> str: self._assert_opened() if self.current_index >= len(self.contents): return "" @@ -194,7 +224,7 @@ def readline(self): self.current_index += 1 return line - def readlines(self, num_lines): + def readlines(self, num_lines: int) -> Sequence[str]: lines = [] for _ in range(num_lines): line = self.readline() @@ -203,23 +233,25 @@ def readlines(self, num_lines): lines.append(line) return lines - def close(self): + def close(self) -> None: self._assert_opened() - self.contents = None + self.contents = [] self.opened = False - def _assert_opened(self): + def _assert_opened(self) -> None: assert 
self.opened - def __enter__(self): + def __enter__(self) -> "StringAsFileSource": self.open() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: self.close() return False - def __str__(self, *args, **kwargs): + def __str__(self, *args: Collection[Any], **kwargs: Mapping[str, Any]) -> str: return "StringAsFileSource" @@ -228,20 +260,20 @@ class ZstAdapter: Adapter class to make the zstandard API work with Rally's decompression abstractions """ - def __init__(self, path): + def __init__(self, path: str): self.fh = open(path, "rb") self.dctx = zstandard.ZstdDecompressor() self.reader = self.dctx.stream_reader(self.fh) - def read(self, size): + def read(self, size: int) -> bytes: return self.reader.read(size) - def close(self): + def close(self) -> None: self.reader.close() self.fh.close() -def ensure_dir(directory, mode=0o777): +def ensure_dir(directory: str, mode: int = 0o777) -> None: """ Ensure that the provided directory and all of its parent directories exist. This function is safe to execute on existing directories (no op). @@ -253,7 +285,7 @@ def ensure_dir(directory, mode=0o777): os.makedirs(directory, mode, exist_ok=True) -def _zipdir(source_directory, archive): +def _zipdir(source_directory: str, archive: zipfile.ZipFile) -> None: for root, _, files in os.walk(source_directory): for file in files: archive.write( @@ -262,7 +294,7 @@ def _zipdir(source_directory, archive): ) -def is_archive(name): +def is_archive(name: str) -> bool: """ :param name: File name to check. Can be either just the file name or optionally also an absolute path. :return: True iff the given file name is an archive that is also recognized for decompression by Rally. 
@@ -271,7 +303,7 @@ def is_archive(name): return ext in SUPPORTED_ARCHIVE_FORMATS -def is_executable(name): +def is_executable(name: str) -> bool: """ :param name: File name to check. :return: True iff given file name is executable and in PATH, all other cases False. @@ -280,7 +312,7 @@ def is_executable(name): return shutil.which(name) is not None -def compress(source_directory, archive_name): +def compress(source_directory: str, archive_name: str) -> None: """ Compress a directory tree. @@ -291,7 +323,7 @@ def compress(source_directory, archive_name): _zipdir(source_directory, archive) -def decompress(zip_name, target_directory): +def decompress(zip_name: str, target_directory: str) -> None: """ Decompresses the provided archive to the target directory. The following file extensions are supported: @@ -315,23 +347,23 @@ def decompress(zip_name, target_directory): _do_decompress(target_directory, zipfile.ZipFile(zip_name)) elif extension == ".bz2": decompressor_args = ["pbzip2", "-d", "-k", "-m10000", "-c"] - decompressor_lib = bz2.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_bz2 = bz2.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_bz2) elif extension == ".zst": decompressor_args = ["pzstd", "-f", "-d", "-c"] - decompressor_lib = ZstAdapter - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_zst = ZstAdapter + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_zst) elif extension == ".gz": decompressor_args = ["pigz", "-d", "-k", "-c"] - decompressor_lib = gzip.open - _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) + decompressor_lib_gzip = gzip.open + _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib_gzip) elif extension in [".tar", ".tar.gz", ".tgz", ".tar.bz2"]: 
_do_decompress(target_directory, tarfile.open(zip_name)) else: raise RuntimeError("Unsupported file extension [%s]. Cannot decompress [%s]" % (extension, zip_name)) -def _do_decompress_manually(target_directory, filename, decompressor_args, decompressor_lib): +def _do_decompress_manually(target_directory: str, filename: str, decompressor_args: List[str], decompressor_lib: Callable) -> None: decompressor_bin = decompressor_args[0] base_path_without_extension = basename(splitext(filename)[0]) @@ -346,7 +378,9 @@ def _do_decompress_manually(target_directory, filename, decompressor_args, decom _do_decompress_manually_with_lib(target_directory, filename, decompressor_lib(filename)) -def _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): +def _do_decompress_manually_external( + target_directory: str, filename: str, base_path_without_extension: str, decompressor_args: List[str] +) -> bool: with open(os.path.join(target_directory, base_path_without_extension), "wb") as new_file: try: subprocess.run(decompressor_args + [filename], stdout=new_file, stderr=subprocess.PIPE, check=True) @@ -358,7 +392,7 @@ def _do_decompress_manually_external(target_directory, filename, base_path_witho return True -def _do_decompress_manually_with_lib(target_directory, filename, compressed_file): +def _do_decompress_manually_with_lib(target_directory: str, filename: str, compressed_file: IO[Any]) -> None: path_without_extension = basename(splitext(filename)[0]) ensure_dir(target_directory) @@ -370,29 +404,34 @@ def _do_decompress_manually_with_lib(target_directory, filename, compressed_file compressed_file.close() -def _do_decompress(target_directory, compressed_file): +def _do_decompress(target_directory: str, compressed_file: Union[zipfile.ZipFile, tarfile.TarFile]) -> None: try: compressed_file.extractall(path=target_directory) except BaseException: - raise RuntimeError("Could not decompress provided archive [%s]" % 
compressed_file.filename) + if isinstance(compressed_file, zipfile.ZipFile): + raise RuntimeError( + f"Could not decompress provided archive [{compressed_file.filename}]. Please check if it is a valid zip file." + ) + if isinstance(compressed_file, tarfile.TarFile): + raise RuntimeError(f"Could not decompress provided archive [{compressed_file.name!r}]. Please check if it is a valid tar file.") finally: compressed_file.close() # just in a dedicated method to ease mocking -def dirname(path: AnyStr): +def dirname(path: AnyStr) -> AnyStr: return os.path.dirname(path) -def basename(path: AnyStr): +def basename(path: AnyStr) -> AnyStr: return os.path.basename(path) -def exists(path: AnyStr): +def exists(path: AnyStr) -> bool: return os.path.exists(path) -def normalize_path(path, cwd="."): +def normalize_path(path: AnyStr, cwd: Any = ".") -> AnyStr: """ Normalizes a path by removing redundant "../" and also expanding the "~" character to the user home directory. :param path: A possibly non-normalized path. @@ -407,7 +446,7 @@ def normalize_path(path, cwd="."): return normalized -def escape_path(path): +def escape_path(path: str) -> str: """ Escapes any characters that might be problematic in shell interactions. @@ -417,7 +456,7 @@ def escape_path(path): return path.replace("\\", "\\\\") -def splitext(file_name): +def splitext(file_name: str) -> Tuple[str, str]: if file_name.endswith(".tar.gz"): return file_name[0:-7], file_name[-7:] elif file_name.endswith(".tar.bz2"): @@ -426,7 +465,7 @@ def splitext(file_name): return os.path.splitext(file_name) -def has_extension(file_name, extension): +def has_extension(file_name: str, extension: str) -> bool: """ Checks whether the given file name has the given extension. @@ -444,7 +483,7 @@ class FileOffsetTable: data file. This helps bulk-indexing clients to advance quickly to a certain position in a large data file. 
""" - def __init__(self, data_file_path, offset_table_path, mode): + def __init__(self, data_file_path: str, offset_table_path: str, mode: str): """ Creates a new FileOffsetTable instance. The constructor should not be called directly but instead the respective factory methods should be used. @@ -457,34 +496,35 @@ def __init__(self, data_file_path, offset_table_path, mode): self.data_file_path = data_file_path self.offset_table_path = offset_table_path self.mode = mode - self.offset_file = None + self.offset_file: Optional[IO[Any]] = None - def exists(self): + def exists(self) -> bool: """ :return: True iff the file offset table already exists. """ return os.path.exists(self.offset_table_path) - def is_valid(self): + def is_valid(self) -> bool: """ :return: True iff the file offset table exists and it is up-to-date. """ return self.exists() and os.path.getmtime(self.offset_table_path) >= os.path.getmtime(self.data_file_path) - def __enter__(self): + def __enter__(self) -> "FileOffsetTable": self.offset_file = open(self.offset_table_path, self.mode) return self - def add_offset(self, line_number, offset): + def add_offset(self, line_number: int, offset: int) -> None: """ Adds a new offset mapping to the file offset table. This method has to be called inside a context-manager block. :param line_number: A line number to add. :param offset: The corresponding offset in bytes. """ + assert self.offset_file is not None, "File offset table must be opened in a context manager block." 
print(f"{line_number};{offset}", file=self.offset_file) - def find_closest_offset(self, target_line_number): + def find_closest_offset(self, target_line_number: int) -> Tuple[int, int]: """ Determines the offset in bytes for the line L in the corresponding data file with the following properties: @@ -498,6 +538,7 @@ def find_closest_offset(self, target_line_number): prior_offset = 0 prior_remaining_lines = target_line_number + assert self.offset_file is not None, "File offset table must be opened in a context manager block." for line in self.offset_file: line_number, offset_in_bytes = (int(i) for i in line.strip().split(";")) if line_number <= target_line_number: @@ -508,13 +549,16 @@ def find_closest_offset(self, target_line_number): return prior_offset, prior_remaining_lines - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType] + ) -> Literal[False]: + assert self.offset_file is not None, "File offset table must be opened in a context manager block." self.offset_file.close() self.offset_file = None return False @classmethod - def create_for_data_file(cls, data_file_path): + def create_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": """ Factory method to create a new file offset table. @@ -523,7 +567,7 @@ def create_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "wt") @classmethod - def read_for_data_file(cls, data_file_path): + def read_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": """ Factory method to read from an existing file offset table. @@ -533,7 +577,7 @@ def read_for_data_file(cls, data_file_path): return cls(data_file_path, f"{data_file_path}.offset", "rt") @staticmethod - def remove(data_file_path): + def remove(data_file_path: str) -> None: """ Removes a file offset table for the provided data path. 
@@ -542,7 +586,7 @@ def remove(data_file_path): os.remove(f"{data_file_path}.offset") -def prepare_file_offset_table(data_file_path): +def prepare_file_offset_table(data_file_path: str) -> Optional[int]: """ Creates a file that contains a mapping from line numbers to file offsets for the provided path. This file is used internally by #skip_lines(data_file_path, data_file) to speed up line skipping. @@ -569,7 +613,7 @@ def prepare_file_offset_table(data_file_path): return None -def remove_file_offset_table(data_file_path): +def remove_file_offset_table(data_file_path: str) -> None: """ Attempts to remove the file offset table for the provided data path. @@ -579,7 +623,7 @@ def remove_file_offset_table(data_file_path): FileOffsetTable.remove(data_file_path) -def skip_lines(data_file_path, data_file, number_of_lines_to_skip): +def skip_lines(data_file_path: str, data_file: IO[Any], number_of_lines_to_skip: int) -> None: """ Skips the first `number_of_lines_to_skip` lines in `data_file` as a side effect. 
@@ -607,7 +651,7 @@ def skip_lines(data_file_path, data_file, number_of_lines_to_skip): data_file.readline() -def get_size(start_path="."): +def get_size(start_path: str = ".") -> int: total_size = 0 for dirpath, _, filenames in os.walk(start_path): for f in filenames: diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index a7b0bc83a..757445767 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -20,7 +20,7 @@ import os import sys from types import ModuleType -from typing import Collection, Union +from typing import Collection, Generator, Tuple, Union from esrally import exceptions from esrally.utils import io @@ -50,7 +50,7 @@ def __init__(self, root_path: Union[str, Collection[str]], component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) - def _modules(self, module_paths: Collection[str], component_name: str, root_path: str): + def _modules(self, module_paths: Collection[str], component_name: str, root_path: str) -> Generator[Tuple[str, str], None, None]: for path in module_paths: for filename in os.listdir(path): name, ext = os.path.splitext(filename) @@ -61,7 +61,7 @@ def _modules(self, module_paths: Collection[str], component_name: str, root_path module_name = "%s.%s" % (component_name, root_relative_path.replace(os.path.sep, ".")) yield module_name, file_absolute_path - def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str): + def _load_component(self, component_name: str, module_dirs: Collection[str], root_path: str) -> ModuleType: # precondition: A module with this name has to exist provided that the caller has called #can_load() before. 
root_module_name = "%s.%s" % (component_name, self.component_entry_point) for name, p in self._modules(module_dirs, component_name, root_path): diff --git a/esrally/utils/process.py b/esrally/utils/process.py index c26c4c0f2..76dfb5d42 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -20,7 +20,7 @@ import shlex import subprocess import time -from typing import IO, Callable, Dict, List, Optional, Union +from typing import IO, Callable, List, Mapping, Optional, Union import psutil @@ -38,7 +38,7 @@ def run_subprocess(command_line: str) -> int: return subprocess.call(command_line, shell=True) -def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> List[str]: +def run_subprocess_with_output(command_line: str, env: Optional[Mapping[str, str]] = None) -> List[str]: logger = logging.getLogger(__name__) logger.debug("Running subprocess [%s] with output.", command_line) command_line_args = shlex.split(command_line) @@ -46,6 +46,7 @@ def run_subprocess_with_output(command_line: str, env: Dict[str, str] = None) -> has_output = True lines = [] while has_output: + assert command_line_process.stdout is not None, "stdout is None" line = command_line_process.stdout.readline() if line: lines.append(line.decode("UTF-8").strip()) @@ -72,10 +73,10 @@ def exit_status_as_bool(runnable: Callable[[], int], quiet: bool = False) -> boo def run_subprocess_with_logging( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, stdin: Optional[Union[FileId, IO[bytes]]] = None, - env: Dict[str, str] = None, + env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> int: """ @@ -117,10 +118,10 @@ def run_subprocess_with_logging( def run_subprocess_with_logging_and_output( command_line: str, - header: str = None, + header: Optional[str] = None, level: LogLevel = logging.INFO, - stdin: FileId = None, - env: Dict[str, str] = None, + stdin: Optional[FileId] = None, + env: Optional[Mapping[str, 
str]] = None, detach: bool = False, ) -> subprocess.CompletedProcess: """ @@ -173,7 +174,7 @@ def is_rally_process(p: psutil.Process) -> bool: def find_all_other_rally_processes() -> List[psutil.Process]: - others = [] + others: List[psutil.Process] = [] for_all_other_processes(is_rally_process, others.append) return others @@ -187,7 +188,7 @@ def redact_cmdline(cmdline: list) -> List[str]: def kill_all(predicate: Callable[[psutil.Process], bool]) -> None: - def kill(p: psutil.Process): + def kill(p: psutil.Process) -> None: logging.getLogger(__name__).info( "Killing lingering process with PID [%s] and command line [%s].", p.pid, redact_cmdline(p.cmdline()) ) diff --git a/pyproject.toml b/pyproject.toml index 627e9ecd3..0cf898796 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,13 +10,13 @@ path = "esrally/_version.py" name = "esrally" dynamic = ["version"] authors = [ - {name="Daniel Mitterdorfer", email="daniel.mitterdorfer@gmail.com"}, + { name = "Daniel Mitterdorfer", email = "daniel.mitterdorfer@gmail.com" }, ] description = "Macrobenchmarking framework for Elasticsearch" readme = "README.md" -license = {text = "Apache License 2.0"} +license = { text = "Apache License 2.0" } requires-python = ">=3.8" -classifiers=[ +classifiers = [ "Topic :: System :: Benchmark", "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: Apache Software License", @@ -81,7 +81,7 @@ dependencies = [ # License: Apache 2.0 "google-auth==1.22.1", # License: BSD - "zstandard==0.21.0" + "zstandard==0.21.0", ] [project.optional-dependencies] @@ -112,6 +112,9 @@ develop = [ "pylint==3.1.0", "trustme==0.9.0", "GitPython==3.1.30", + # mypy + "types-psutil==5.9.4", + "types-tabulate==0.8.9", ] [project.scripts] @@ -181,17 +184,16 @@ disable_error_code = [ "union-attr", "var-annotated", ] -files = [ - "esrally/", - "it/", - "tests/", -] +files = ["esrally/", "it/", "tests/"] [[tool.mypy.overrides]] module = [ "esrally.mechanic.team", "esrally.utils.modules", + 
"esrally.utils.io", + "esrally.utils.process", ] +disallow_incomplete_defs = true # this should be a copy of disabled_error_code from above enable_error_code = [ "assignment", From 0660ab9172359c1a98f59b5de632627185601915 Mon Sep 17 00:00:00 2001 From: favilo Date: Tue, 2 Jul 2024 09:26:41 -0700 Subject: [PATCH 14/26] Make stdin more generic --- esrally/utils/process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/esrally/utils/process.py b/esrally/utils/process.py index 76dfb5d42..b6d8d9f1c 100644 --- a/esrally/utils/process.py +++ b/esrally/utils/process.py @@ -120,7 +120,7 @@ def run_subprocess_with_logging_and_output( command_line: str, header: Optional[str] = None, level: LogLevel = logging.INFO, - stdin: Optional[FileId] = None, + stdin: Optional[Union[FileId, IO[bytes]]] = None, env: Optional[Mapping[str, str]] = None, detach: bool = False, ) -> subprocess.CompletedProcess: From f6c7aa830eab1a0de9d3abd2c4e9ac7f201d25e0 Mon Sep 17 00:00:00 2001 From: favilo Date: Tue, 2 Jul 2024 09:40:44 -0700 Subject: [PATCH 15/26] Shouldn't have used an assert there --- esrally/mechanic/team.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/esrally/mechanic/team.py b/esrally/mechanic/team.py index 9b1fa678c..ddaa48d85 100644 --- a/esrally/mechanic/team.py +++ b/esrally/mechanic/team.py @@ -180,13 +180,19 @@ def load_car(self, name: str, car_params: Optional[Mapping[str, Any]] = None) -> config = self._config_loader(car_config_file) root_paths: List[str] = [] config_paths: List[str] = [] - config_base_vars: Mapping[str, Any] = {} + config_base_vars: MutableMapping[str, Any] = {} + description = self._value(config, ["meta", "description"], default="") + assert isinstance(description, str), f"Car [{name}] defines an invalid description [{description}]." + car_type = self._value(config, ["meta", "type"], default="car") + assert isinstance(car_type, str), f"Car [{name}] defines an invalid type [{car_type}]." 
+ config_base = self._value(config, ["config", "base"], default="") assert config_base is not None, f"Car [{name}] does not define a config base." assert isinstance(config_base, str), f"Car [{name}] defines an invalid config base [{config_base}]." config_bases = config_base.split(",") + for base in config_bases: if base: root_path = os.path.join(self.cars_dir, base) @@ -216,12 +222,11 @@ def _config_loader(self, file_name: str) -> "configparser.ConfigParser": def _value( self, cfg: "configparser.ConfigParser", section_path: Union[str, Collection[str]], default: Optional[str] = None - ) -> Optional[Mapping[str, Any]]: + ) -> Optional[Union[str, Mapping[str, Any]]]: path: Collection[str] = [section_path] if (isinstance(section_path, str)) else section_path - current_cfg = cfg + current_cfg: Union["configparser.ConfigParser", Mapping[str, Any], str] = cfg for k in path: - assert isinstance(current_cfg, dict), f"Expected a dict but got [{current_cfg}] instead." - if k in current_cfg: + if not isinstance(current_cfg, str) and k in current_cfg: current_cfg = current_cfg[k] else: return default From 4d0a3e2fdf871675749c8e62d87f5641ae33aae7 Mon Sep 17 00:00:00 2001 From: favilo Date: Tue, 2 Jul 2024 10:59:53 -0700 Subject: [PATCH 16/26] Replace Generator with Iterator --- esrally/utils/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/esrally/utils/modules.py b/esrally/utils/modules.py index 757445767..69ec87d22 100644 --- a/esrally/utils/modules.py +++ b/esrally/utils/modules.py @@ -20,7 +20,7 @@ import os import sys from types import ModuleType -from typing import Collection, Generator, Tuple, Union +from typing import Collection, Iterator, Tuple, Union from esrally import exceptions from esrally.utils import io @@ -50,7 +50,7 @@ def __init__(self, root_path: Union[str, Collection[str]], component_entry_point self.recurse = recurse self.logger = logging.getLogger(__name__) - def _modules(self, module_paths: Collection[str], component_name: 
str, root_path: str) -> Generator[Tuple[str, str], None, None]: + def _modules(self, module_paths: Collection[str], component_name: str, root_path: str) -> Iterator[Tuple[str, str]]: for path in module_paths: for filename in os.listdir(path): name, ext = os.path.splitext(filename) From d21ad4b56031ffebf40b95113904a11b4422caf9 Mon Sep 17 00:00:00 2001 From: favilo Date: Wed, 3 Jul 2024 09:10:59 -0700 Subject: [PATCH 17/26] testing $VIRTUAL_ENV for python_executable. This will let it use the correct python environment by default. Ideally. I will revert if this breaks CI again. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0cf898796..6f01e6f7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,7 @@ xfail_strict = true # is to keep "arg-type" error code, while disabling other error codes. # Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] +python_executable="$VIRTUAL_ENV/bin/python" python_version = "3.8" # subset of "strict", kept at global config level as some of the options are # supported only at this level From 2058474dc796f3e81f7bc4379041e82a57e1c414 Mon Sep 17 00:00:00 2001 From: Grzegorz Banasiak Date: Mon, 8 Jul 2024 19:36:12 +0200 Subject: [PATCH 18/26] Switch to local repo for mypy in pre-commit config --- .pre-commit-config.yaml | 13 +++++-------- pyproject.toml | 1 + 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8649788f6..8786c153e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,17 +18,14 @@ repos: hooks: - id: isort - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + - repo: local hooks: - id: mypy - additional_dependencies: [ - "elasticsearch[async]==8.6.1", - "elastic-transport==8.4.1", - "types-tabulate==0.8.9", - ] + name: mypy + entry: mypy + language: system + types: [python] args: [ - "--ignore-missing-imports", "--config", "pyproject.toml" ] diff 
--git a/pyproject.toml b/pyproject.toml index 6f01e6f7e..f05634932 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,6 +113,7 @@ develop = [ "trustme==0.9.0", "GitPython==3.1.30", # mypy + "mypy==1.10.1", "types-psutil==5.9.4", "types-tabulate==0.8.9", ] From 78e5c3f4ec439b0cee0faff257f34be61aacf771 Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 10:49:08 -0700 Subject: [PATCH 19/26] Remove python_executable since it breaks GH actions. --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f05634932..3e0999378 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,7 +157,6 @@ xfail_strict = true # is to keep "arg-type" error code, while disabling other error codes. # Ref: https://github.com/elastic/rally/pull/1798 [tool.mypy] -python_executable="$VIRTUAL_ENV/bin/python" python_version = "3.8" # subset of "strict", kept at global config level as some of the options are # supported only at this level From 4dd09095263e0c6d6629186960b682394e49cf9c Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 11:48:48 -0700 Subject: [PATCH 20/26] Change to newer version of `black` so we don't have unnecessary blank lines between overridden type definitions --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8786c153e..ba7fa22bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: ] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 24.4.2 hooks: - id: black From b94c1a11b11c2692ab773704685e710fd7dbf589 Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 12:00:03 -0700 Subject: [PATCH 21/26] Add type annotations with overrides to allow setting the default --- esrally/utils/io.py | 12 +++++++++++- pyproject.toml | 3 +++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 47bcafdda..435073a02 100644 
--- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -41,6 +41,7 @@ Tuple, Type, Union, + overload, ) import zstandard @@ -431,7 +432,16 @@ def exists(path: AnyStr) -> bool: return os.path.exists(path) -def normalize_path(path: AnyStr, cwd: Any = ".") -> AnyStr: +@overload +def normalize_path(path: str) -> str: ... +@overload +def normalize_path(path: str, cwd: str = ".") -> str: ... +@overload +def normalize_path(path: bytes) -> bytes: ... +@overload +def normalize_path(path: bytes, cwd: bytes = b".") -> bytes: ... +def normalize_path(path, cwd="."): + # This is a bug in mypy, see https://github.com/python/mypy/issues/3737 """ Normalizes a path by removing redundant "../" and also expanding the "~" character to the user home directory. :param path: A possibly non-normalized path. diff --git a/pyproject.toml b/pyproject.toml index 3e0999378..0700f549d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,9 +113,12 @@ develop = [ "trustme==0.9.0", "GitPython==3.1.30", # mypy + "boto3-stubs==1.26.125", "mypy==1.10.1", "types-psutil==5.9.4", "types-tabulate==0.8.9", + "types-urllib3==1.26.19", + "types-requests<2.32.0", ] [project.scripts] From 56ef7e5224af543f5f8b2aaa06324c56926b3d23 Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 12:29:13 -0700 Subject: [PATCH 22/26] Remove `Any` from file return types. This ensures we are using a single return type for `read()` and friends. 
--- esrally/utils/io.py | 21 +++++++++++---------- pyproject.toml | 1 + 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 435073a02..bd3775f01 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -33,6 +33,7 @@ AnyStr, Callable, Collection, + Generic, List, Literal, Mapping, @@ -51,7 +52,7 @@ SUPPORTED_ARCHIVE_FORMATS = [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".zst"] -class FileSource: +class FileSource(Generic[AnyStr]): """ FileSource is a wrapper around a plain file which simplifies testing of file I/O calls. """ @@ -60,7 +61,7 @@ def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f: Optional[IO[Any]] = None + self.f: Optional[IO[AnyStr]] = None def open(self) -> "FileSource": self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) @@ -71,15 +72,15 @@ def seek(self, offset: int) -> None: assert self.f is not None, "File is not open" self.f.seek(offset) - def read(self) -> bytes: + def read(self) -> AnyStr: assert self.f is not None, "File is not open" return self.f.read() - def readline(self) -> bytes: + def readline(self) -> AnyStr: assert self.f is not None, "File is not open" return self.f.readline() - def readlines(self, num_lines: int) -> Sequence[bytes]: + def readlines(self, num_lines: int) -> Sequence[AnyStr]: assert self.f is not None, "File is not open" lines = [] f = self.f @@ -118,7 +119,7 @@ def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.file_name = file_name self.mode = mode self.encoding = encoding - self.f: Optional[IO[Any]] = None + self.f: Optional[IO[bytes]] = None self.mm: Optional[mmap.mmap] = None def open(self) -> "MmapSource": @@ -393,7 +394,7 @@ def _do_decompress_manually_external( return True -def _do_decompress_manually_with_lib(target_directory: str, filename: str, compressed_file: IO[Any]) -> None: +def 
_do_decompress_manually_with_lib(target_directory: str, filename: str, compressed_file: IO[bytes]) -> None: path_without_extension = basename(splitext(filename)[0]) ensure_dir(target_directory) @@ -506,7 +507,7 @@ def __init__(self, data_file_path: str, offset_table_path: str, mode: str): self.data_file_path = data_file_path self.offset_table_path = offset_table_path self.mode = mode - self.offset_file: Optional[IO[Any]] = None + self.offset_file: Optional[IO[bytes]] = None def exists(self) -> bool: """ @@ -550,7 +551,7 @@ def find_closest_offset(self, target_line_number: int) -> Tuple[int, int]: assert self.offset_file is not None, "File offset table must be opened in a context manager block." for line in self.offset_file: - line_number, offset_in_bytes = (int(i) for i in line.strip().split(";")) + line_number, offset_in_bytes = (int(i) for i in line.strip().split(b";")) if line_number <= target_line_number: prior_offset = offset_in_bytes prior_remaining_lines = target_line_number - line_number @@ -633,7 +634,7 @@ def remove_file_offset_table(data_file_path: str) -> None: FileOffsetTable.remove(data_file_path) -def skip_lines(data_file_path: str, data_file: IO[Any], number_of_lines_to_skip: int) -> None: +def skip_lines(data_file_path: str, data_file: IO[AnyStr], number_of_lines_to_skip: int) -> None: """ Skips the first `number_of_lines_to_skip` lines in `data_file` as a side effect. 
diff --git a/pyproject.toml b/pyproject.toml index 0700f549d..3592da9c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ develop = [ "types-tabulate==0.8.9", "types-urllib3==1.26.19", "types-requests<2.32.0", + "types-jsonschema==3.2.0", ] [project.scripts] From 4981f13f10286fd2e048e70b607801ed91cae45a Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 12:49:06 -0700 Subject: [PATCH 23/26] new black version required new formatting --- esrally/driver/scheduler.py | 6 ++---- esrally/types.py | 15 +++++---------- tests/track/loader_test.py | 2 +- 3 files changed, 8 insertions(+), 15 deletions(-) diff --git a/esrally/driver/scheduler.py b/esrally/driver/scheduler.py index 4cd9c7b4b..05c591ad3 100644 --- a/esrally/driver/scheduler.py +++ b/esrally/driver/scheduler.py @@ -166,8 +166,7 @@ def remove_scheduler(name): class SimpleScheduler(ABC): @abstractmethod - def next(self, current): - ... + def next(self, current): ... class Scheduler(ABC): @@ -178,8 +177,7 @@ def after_request(self, now, weight, unit, request_meta_data): pass @abstractmethod - def next(self, current): - ... + def next(self, current): ... # Deprecated diff --git a/esrally/types.py b/esrally/types.py index d64dbc896..bb6e8c2b1 100644 --- a/esrally/types.py +++ b/esrally/types.py @@ -162,17 +162,12 @@ class Config(Protocol): - def add(self, scope, section: Section, key: Key, value: Any) -> None: - ... + def add(self, scope, section: Section, key: Key, value: Any) -> None: ... - def add_all(self, source: _Config, section: Section) -> None: - ... + def add_all(self, source: _Config, section: Section) -> None: ... - def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: - ... + def opts(self, section: Section, key: Key, default_value=None, mandatory: bool = True) -> Any: ... - def all_opts(self, section: Section) -> dict: - ... + def all_opts(self, section: Section) -> dict: ... - def exists(self, section: Section, key: Key) -> bool: - ... 
+ def exists(self, section: Section, key: Key) -> bool: ... diff --git a/tests/track/loader_test.py b/tests/track/loader_test.py index 20b9f64a3..56cc65911 100644 --- a/tests/track/loader_test.py +++ b/tests/track/loader_test.py @@ -2758,7 +2758,7 @@ def test_parse_valid_without_types(self): "indices": [ { "name": "index-historical", - "body": "body.json" + "body": "body.json", # no type information here } ], From e30278532f5475eb1af5b3aeb445ad35c65b1583 Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 13:03:27 -0700 Subject: [PATCH 24/26] Add typing_extensions, so we can more easily return a generic type from builder methods --- esrally/utils/io.py | 22 +++++++++++++--------- pyproject.toml | 1 + 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index bd3775f01..8626ad07f 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -47,6 +47,10 @@ import zstandard +# This was introduced in Python 3.11 to `typing` older versions need `typing_extensions` +# but they are treated the same by mypy, so I'm not going to use conditional imports here +from typing_extensions import Self + from esrally.utils import console SUPPORTED_ARCHIVE_FORMATS = [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".zst"] @@ -63,7 +67,7 @@ def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.encoding = encoding self.f: Optional[IO[AnyStr]] = None - def open(self) -> "FileSource": + def open(self) -> Self: self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) # allow for chaining return self @@ -96,7 +100,7 @@ def close(self) -> None: self.f.close() self.f = None - def __enter__(self) -> "FileSource": + def __enter__(self) -> Self: self.open() return self @@ -122,7 +126,7 @@ def __init__(self, file_name: str, mode: str, encoding: str = "utf-8"): self.f: Optional[IO[bytes]] = None self.mm: Optional[mmap.mmap] = None - def open(self) -> "MmapSource": + def open(self) -> 
Self: self.f = open(self.file_name, mode="r+b") self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) self.mm.madvise(mmap.MADV_SEQUENTIAL) @@ -161,7 +165,7 @@ def close(self) -> None: self.f.close() self.f = None - def __enter__(self) -> "MmapSource": + def __enter__(self) -> Self: self.open() return self @@ -205,7 +209,7 @@ def __init__(self, contents: Sequence[str], mode: str, encoding: str = "utf-8"): self.current_index = 0 self.opened = False - def open(self) -> "StringAsFileSource": + def open(self) -> Self: self.opened = True return self @@ -243,7 +247,7 @@ def close(self) -> None: def _assert_opened(self) -> None: assert self.opened - def __enter__(self) -> "StringAsFileSource": + def __enter__(self) -> Self: self.open() return self @@ -521,7 +525,7 @@ def is_valid(self) -> bool: """ return self.exists() and os.path.getmtime(self.offset_table_path) >= os.path.getmtime(self.data_file_path) - def __enter__(self) -> "FileOffsetTable": + def __enter__(self) -> Self: self.offset_file = open(self.offset_table_path, self.mode) return self @@ -569,7 +573,7 @@ def __exit__( return False @classmethod - def create_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": + def create_for_data_file(cls, data_file_path: str) -> Self: """ Factory method to create a new file offset table. @@ -578,7 +582,7 @@ def create_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": return cls(data_file_path, f"{data_file_path}.offset", "wt") @classmethod - def read_for_data_file(cls, data_file_path: str) -> "FileOffsetTable": + def read_for_data_file(cls, data_file_path: str) -> Self: """ Factory method to read from an existing file offset table. 
diff --git a/pyproject.toml b/pyproject.toml index 3592da9c0..c7f0988cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,6 +120,7 @@ develop = [ "types-urllib3==1.26.19", "types-requests<2.32.0", "types-jsonschema==3.2.0", + "typing-extensions==4.12.2", ] [project.scripts] From 60c8f036d6a3aa53ab56584e7833021d222ced37 Mon Sep 17 00:00:00 2001 From: favilo Date: Mon, 8 Jul 2024 14:33:41 -0700 Subject: [PATCH 25/26] Moving typing-extensions to dependencies, since it is required --- esrally/utils/io.py | 2 +- pyproject.toml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index 8626ad07f..f29522c79 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -47,7 +47,7 @@ import zstandard -# This was introduced in Python 3.11 to `typing` older versions need `typing_extensions` +# This was introduced in Python 3.11 to `typing`; older versions need `typing_extensions` # but they are treated the same by mypy, so I'm not going to use conditional imports here from typing_extensions import Self diff --git a/pyproject.toml b/pyproject.toml index c7f0988cf..98694685c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,8 @@ dependencies = [ "google-auth==1.22.1", # License: BSD "zstandard==0.21.0", + # License: Python Software Foundation License + "typing-extensions==4.12.2", ] [project.optional-dependencies] @@ -120,7 +122,6 @@ develop = [ "types-urllib3==1.26.19", "types-requests<2.32.0", "types-jsonschema==3.2.0", - "typing-extensions==4.12.2", ] [project.scripts] From f5125d262bb9bd3f6355d34110cae4cd3d7235d7 Mon Sep 17 00:00:00 2001 From: Grzegorz Banasiak Date: Tue, 9 Jul 2024 10:42:12 +0200 Subject: [PATCH 26/26] Change FileOffsetTable file object type from bytes to str --- esrally/utils/io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/esrally/utils/io.py b/esrally/utils/io.py index f29522c79..e4e83b2cd 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py 
@@ -511,7 +511,7 @@ def __init__(self, data_file_path: str, offset_table_path: str, mode: str): self.data_file_path = data_file_path self.offset_table_path = offset_table_path self.mode = mode - self.offset_file: Optional[IO[bytes]] = None + self.offset_file: Optional[IO[str]] = None def exists(self) -> bool: """ @@ -555,7 +555,7 @@ def find_closest_offset(self, target_line_number: int) -> Tuple[int, int]: assert self.offset_file is not None, "File offset table must be opened in a context manager block." for line in self.offset_file: - line_number, offset_in_bytes = (int(i) for i in line.strip().split(b";")) + line_number, offset_in_bytes = (int(i) for i in line.strip().split(";")) if line_number <= target_line_number: prior_offset = offset_in_bytes prior_remaining_lines = target_line_number - line_number