From 3b5eee2b892743f4124153fab34b12e9317ec872 Mon Sep 17 00:00:00 2001 From: Daniel Mitterdorfer Date: Sun, 29 Mar 2020 12:13:34 +0200 Subject: [PATCH] Add an asyncio-based load generator (#935) With this commit we add a new experimental subcommand `race-aync` to Rally. It allows to specify significantly more clients than the current `race` subcommand. The reason for this is that under the hood, `race-async` uses `asyncio` and runs all clients in a single event loop. Contrary to that, `race` uses an actor system under the hood and maps each client to one process. As the new subcommand is very experimental and not yet meant to be used broadly, there is no accompanying user documentation in this PR. Instead, we plan to build on top of this PR and expand the load generator to take advantage of multiple cores before we consider this usable in production (it will likely keep its experimental status though). In this PR we also implement a compatibility layer into the current load generator so both work internally now with `asyncio`. Consequently, we have already adapted all Rally tracks with a backwards-compatibility layer (see elastic/rally-tracks#97 and elastic/rally-eventdata-track#80). Closes #852 Relates #916 --- create-notice.sh | 7 + docs/adding_tracks.rst | 46 +- docs/migrate.rst | 49 ++ docs/track.rst | 8 + esrally/async_connection.py | 134 ++++ esrally/client.py | 93 +++ esrally/driver/__init__.py | 3 + esrally/driver/async_driver.py | 370 +++++++++ esrally/driver/driver.py | 179 +++-- esrally/driver/runner.py | 560 +++++++------ esrally/metrics.py | 11 +- esrally/racecontrol.py | 182 +++-- esrally/rally.py | 11 +- esrally/reporter.py | 8 + esrally/track/loader.py | 7 +- esrally/track/params.py | 10 +- esrally/utils/io.py | 61 ++ integration-test.sh | 4 +- setup.cfg | 3 + setup.py | 11 +- tests/__init__.py | 32 + tests/driver/async_driver_test.py | 213 +++++ tests/driver/driver_test.py | 388 +++++---- tests/driver/runner_test.py | 1231 ++++++++++++++++++----------- tests/track/params_test.py | 134 ++-- 25 files changed, 2656 insertions(+), 1099 deletions(-) create mode 100644 esrally/async_connection.py create mode 100644 esrally/driver/async_driver.py create mode 100644 tests/driver/async_driver_test.py diff --git a/create-notice.sh b/create-notice.sh index f21519752..03a5d381b 100755 --- a/create-notice.sh +++ b/create-notice.sh @@ -43,6 +43,7 @@ function main { printf "The source code can be obtained at https://github.com/certifi/python-certifi\n" >> "${OUTPUT_FILE}" add_license "certifi" "https://raw.githubusercontent.com/certifi/python-certifi/master/LICENSE" add_license "elasticsearch" "https://raw.githubusercontent.com/elastic/elasticsearch-py/master/LICENSE" + add_license "elasticsearch-async" "https://raw.githubusercontent.com/elastic/elasticsearch-py-async/master/LICENSE" add_license "jinja2" "https://raw.githubusercontent.com/pallets/jinja/master/LICENSE.rst" add_license "jsonschema" "https://raw.githubusercontent.com/Julian/jsonschema/master/COPYING" add_license "psutil" "https://raw.githubusercontent.com/giampaolo/psutil/master/LICENSE" @@ -50,12 +51,18 @@ function main { add_license "tabulate" "https://bitbucket.org/astanin/python-tabulate/raw/03182bf9b8a2becbc54d17aa7e3e7dfed072c5f5/LICENSE" add_license "thespian" "https://raw.githubusercontent.com/kquick/Thespian/master/LICENSE.txt" add_license "boto3" "https://raw.githubusercontent.com/boto/boto3/develop/LICENSE" + add_license "yappi" "https://raw.githubusercontent.com/sumerc/yappi/master/LICENSE" + add_license "ijson" "https://raw.githubusercontent.com/ICRAR/ijson/master/LICENSE.txt" # transitive dependencies # Jinja2 -> Markupsafe add_license "Markupsafe" "https://raw.githubusercontent.com/pallets/markupsafe/master/LICENSE.rst" # elasticsearch -> urllib3 add_license "urllib3" "https://raw.githubusercontent.com/shazow/urllib3/master/LICENSE.txt" + #elasticsearch_async -> aiohttp + add_license "aiohttp" "https://raw.githubusercontent.com/aio-libs/aiohttp/master/LICENSE.txt" + #elasticsearch_async -> async_timeout + add_license "async_timeout" "https://raw.githubusercontent.com/aio-libs/async-timeout/master/LICENSE" # boto3 -> s3transfer add_license "s3transfer" "https://raw.githubusercontent.com/boto/s3transfer/develop/LICENSE.txt" # boto3 -> jmespath diff --git a/docs/adding_tracks.rst b/docs/adding_tracks.rst index d00d2676f..8ab263dfd 100644 --- a/docs/adding_tracks.rst +++ b/docs/adding_tracks.rst @@ -881,17 +881,15 @@ In ``track.json`` set the ``operation-type`` to "percolate" (you can choose this Then create a file ``track.py`` next to ``track.json`` and implement the following two functions:: - def percolate(es, params): - es.percolate( - index="queries", - doc_type="content", - body=params["body"] - ) - + async def percolate(es, params): + await es.percolate( + index="queries", + doc_type="content", + body=params["body"] + ) def register(registry): - registry.register_runner("percolate", percolate) - + registry.register_runner("percolate", percolate, async_runner=True) The function ``percolate`` is the actual runner and takes the following parameters: @@ -906,11 +904,25 @@ This function can return: Similar to a parameter source you also need to bind the name of your operation type to the function within ``register``. +To illustrate how to use custom return values, suppose we want to implement a custom runner that calls the `pending tasks API `_ and returns the number of pending tasks as additional meta-data:: + + async def pending_tasks(es, params): + response = await es.cluster.pending_tasks() + return { + "weight": 1, + "unit": "ops", + "pending-tasks-count": len(response["tasks"]) + } + + def register(registry): + registry.register_runner("pending-tasks", pending_tasks, async_runner=True) + + If you need more control, you can also implement a runner class. The example above, implemented as a class looks as follows:: class PercolateRunner: - def __call__(self, es, params): - es.percolate( + async def __call__(self, es, params): + await es.percolate( index="queries", doc_type="content", body=params["body"] @@ -920,10 +932,12 @@ If you need more control, you can also implement a runner class. The example abo return "percolate" def register(registry): - registry.register_runner("percolate", PercolateRunner()) + registry.register_runner("percolate", PercolateRunner(), async_runner=True) + +The actual runner is implemented in the method ``__call__`` and the same return value conventions apply as for functions. For debugging purposes you should also implement ``__repr__`` and provide a human-readable name for your runner. Finally, you need to register your runner in the ``register`` function. -The actual runner is implemented in the method ``__call__`` and the same return value conventions apply as for functions. For debugging purposes you should also implement ``__repr__`` and provide a human-readable name for your runner. Finally, you need to register your runner in the ``register`` function. Runners also support Python's `context manager `_ interface. Rally uses a new context for each request. Implementing the context manager interface can be handy for cleanup of resources after executing an operation. Rally uses it, for example, to clear open scrolls. +Runners also support Python's `asynchronous context manager `_ interface. Rally uses a new context for each request. Implementing the asynchronous context manager interface can be handy for cleanup of resources after executing an operation. Rally uses it, for example, to clear open scrolls. If you have specified multiple Elasticsearch clusters using :ref:`target-hosts ` you can make Rally pass a dictionary of client connections instead of one for the ``default`` cluster in the ``es`` parameter. @@ -938,14 +952,14 @@ Example (assuming Rally has been invoked specifying ``default`` and ``remote`` i class CreateIndexInRemoteCluster: multi_cluster = True - def __call__(self, es, params): - es['remote'].indices.create(index='remote-index') + async def __call__(self, es, params): + await es["remote"].indices.create(index="remote-index") def __repr__(self, *args, **kwargs): return "create-index-in-remote-cluster" def register(registry): - registry.register_runner("create-index-in-remote-cluster", CreateIndexInRemoteCluster()) + registry.register_runner("create-index-in-remote-cluster", CreateIndexInRemoteCluster(), async_runner=True) .. note:: diff --git a/docs/migrate.rst b/docs/migrate.rst index 2e55a7f77..b020e2ce2 100644 --- a/docs/migrate.rst +++ b/docs/migrate.rst @@ -9,6 +9,55 @@ Minimum Python version is 3.8.0 Rally 1.5.0 requires Python 3.8.0. Check the :ref:`updated installation instructions ` for more details. +Meta-Data for queries are omitted +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Rally 1.5.0 does not determine query meta-data anymore by default to reduce the risk of client-side bottlenecks. The following meta-data fields are affected: + +* ``hits`` +* ``hits_relation`` +* ``timed_out`` +* ``took`` + +If you still want to retrieve them (risking skewed results due to additional overhead), set the new property ``detailed-results`` to ``true`` for any operation of type ``search``. + +Runner API uses asyncio +^^^^^^^^^^^^^^^^^^^^^^^ + +In order to support more concurrent clients in the future, Rally is moving from a synchronous model to an asynchronous model internally. With Rally 1.5.0 all custom runners need to be implemented using async APIs and a new bool argument ``async_runner=True`` needs to be provided upon registration. Below is an example how to migrate a custom runner function. + +A custom runner prior to Rally 1.5.0:: + + def percolate(es, params): + es.percolate( + index="queries", + doc_type="content", + body=params["body"] + ) + + def register(registry): + registry.register_runner("percolate", percolate) + +With Rally 1.5.0, the implementation changes as follows:: + + async def percolate(es, params): + await es.percolate( + index="queries", + doc_type="content", + body=params["body"] + ) + + def register(registry): + registry.register_runner("percolate", percolate, async_runner=True) + +Apply to the following changes for each custom runner: + +* Prefix the function signature with ``async``. +* Add an ``await`` keyword before each Elasticsearch API call. +* Add ``async_runner=True`` as the last argument to the ``register_runner`` function. + +For more details please refer to the updated documentation on :ref:`custom runners `. + Migrating to Rally 1.4.1 ------------------------ diff --git a/docs/track.rst b/docs/track.rst index 890361d05..46e815b99 100644 --- a/docs/track.rst +++ b/docs/track.rst @@ -402,9 +402,17 @@ With the operation type ``search`` you can execute `request body searches = self.interval: + self.current = 0 + self.fn() + + def __repr__(self): + return "timer task for {} firing every {}s.".format(str(self.fn), self.interval) + + def __init__(self, wakeup_interval=1): + """ + :param wakeup_interval: The interval in seconds in which the timer will check whether it has been stopped or + schedule tasks. Default: 1 second. + """ + self.stop_event = threading.Event() + self.tasks = [] + self.wakeup_interval = wakeup_interval + self.logger = logging.getLogger(__name__) + + def add_task(self, fn, interval): + self.tasks.append(Timer.Task(fn, interval, self.wakeup_interval)) + + def stop(self): + self.stop_event.set() + + def __call__(self, *args, **kwargs): + while not self.stop_event.is_set(): + for t in self.tasks: + self.logger.debug("Invoking [%s]", t) + t.may_run() + # allow early exit even if a longer sleeping period is requested + if self.stop_event.is_set(): + self.logger.debug("Stopping timer due to external event.") + break + time.sleep(self.wakeup_interval) + + +class AsyncDriver: + def __init__(self, config, track, challenge, es_client_factory_class=client.EsClientFactory): + self.logger = logging.getLogger(__name__) + self.config = config + self.track = track + self.challenge = challenge + self.es_client_factory = es_client_factory_class + self.metrics_store = None + + self.progress_reporter = console.progress() + self.throughput_calculator = driver.ThroughputCalculator() + self.raw_samples = [] + self.most_recent_sample_per_client = {} + + self.current_tasks = [] + + self.telemetry = None + self.es_clients = None + + self.quiet = self.config.opts("system", "quiet.mode", mandatory=False, default_value=False) + # TODO: Change the default value to `False` once this implementation becomes the default + self.debug_event_loop = self.config.opts("system", "async.debug", mandatory=False, default_value=True) + self.abort_on_error = self.config.opts("driver", "on.error") == "abort" + self.profiling_enabled = self.config.opts("driver", "profiling") + self.sampler = None + + def create_es_clients(self, sync=True): + all_hosts = self.config.opts("client", "hosts").all_hosts + es = {} + for cluster_name, cluster_hosts in all_hosts.items(): + all_client_options = self.config.opts("client", "options").all_client_options + cluster_client_options = dict(all_client_options[cluster_name]) + # Use retries to avoid aborts on long living connections for telemetry devices + cluster_client_options["retry-on-timeout"] = True + + client_factory = self.es_client_factory(cluster_hosts, cluster_client_options) + if sync: + es[cluster_name] = client_factory.create() + else: + es[cluster_name] = client_factory.create_async() + return es + + def prepare_telemetry(self): + enabled_devices = self.config.opts("telemetry", "devices") + telemetry_params = self.config.opts("telemetry", "params") + + es = self.es_clients + es_default = self.es_clients["default"] + self.telemetry = telemetry.Telemetry(enabled_devices, devices=[ + telemetry.NodeStats(telemetry_params, es, self.metrics_store), + telemetry.ExternalEnvironmentInfo(es_default, self.metrics_store), + telemetry.ClusterEnvironmentInfo(es_default, self.metrics_store), + telemetry.JvmStatsSummary(es_default, self.metrics_store), + telemetry.IndexStats(es_default, self.metrics_store), + telemetry.MlBucketProcessingTime(es_default, self.metrics_store), + telemetry.CcrStats(telemetry_params, es, self.metrics_store), + telemetry.RecoveryStats(telemetry_params, es, self.metrics_store) + ]) + + def wait_for_rest_api(self): + skip_rest_api_check = self.config.opts("mechanic", "skip.rest.api.check") + if skip_rest_api_check: + self.logger.info("Skipping REST API check.") + else: + es_default = self.es_clients["default"] + self.logger.info("Checking if REST API is available.") + if client.wait_for_rest_layer(es_default, max_attempts=40): + self.logger.info("REST API is available.") + else: + self.logger.error("REST API layer is not yet available. Stopping benchmark.") + raise exceptions.SystemSetupError("Elasticsearch REST API layer is not available.") + + def retrieve_cluster_info(self): + # noinspection PyBroadException + try: + return self.es_clients["default"].info() + except BaseException: + self.logger.exception("Could not retrieve cluster info on benchmark start") + return None + + def setup(self): + if self.track.has_plugins: + # no need to fetch the track once more; it has already been updated + track.track_repo(self.config, fetch=False, update=False) + # load track plugins eagerly to initialize the respective parameter sources + track.load_track_plugins(self.config, runner.register_runner, scheduler.register_scheduler) + track.prepare_track(self.track, self.config) + + self.metrics_store = metrics.metrics_store(cfg=self.config, + track=self.track.name, + challenge=self.challenge.name, + read_only=False) + self.es_clients = self.create_es_clients() + self.wait_for_rest_api() + self.prepare_telemetry() + + cluster_info = self.retrieve_cluster_info() + cluster_version = cluster_info["version"] if cluster_info else {} + return cluster_version.get("build_flavor", "oss"), cluster_version.get("number"), cluster_version.get("build_hash") + + def run(self): + self.logger.info("Benchmark is about to start.") + # ensure relative time starts when the benchmark starts. + self.reset_relative_time() + self.logger.info("Attaching cluster-level telemetry devices.") + self.telemetry.on_benchmark_start() + self.logger.info("Cluster-level telemetry devices are now attached.") + + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + timer = Timer() + timer.add_task(fn=self.update_samples, interval=1) + timer.add_task(fn=self.post_process_samples, interval=30) + timer.add_task(fn=self.update_progress_message, interval=1) + + pool.submit(timer) + + # needed because a new thread (that is not the main thread) does not have an event loop + loop = asyncio.new_event_loop() + loop.set_debug(self.debug_event_loop) + asyncio.set_event_loop(loop) + loop.set_exception_handler(self._logging_exception_handler) + + track.set_absolute_data_path(self.config, self.track) + runner.register_default_runners() + # We can skip this here as long as we run in the same process; it has already been done in #setup() + # if self.track.has_plugins: + # track.load_track_plugins(self.config, runner.register_runner, scheduler.register_scheduler) + try: + benchmark_runner = driver.AsyncProfiler(self._run_benchmark) if self.profiling_enabled else self._run_benchmark + loop.run_until_complete(benchmark_runner()) + self.telemetry.on_benchmark_stop() + self.logger.info("All steps completed.") + return self.metrics_store.to_externalizable() + finally: + self.logger.debug("Stopping timer...") + timer.stop() + pool.shutdown() + self.logger.debug("Closing event loop...") + loop.close() + self.progress_reporter.finish() + self.logger.debug("Closing metrics store...") + self.metrics_store.close() + # immediately clear as we don't need it anymore and it can consume a significant amount of memory + self.metrics_store = None + + def _logging_exception_handler(self, loop, context): + self.logger.error("Uncaught exception in event loop: %s", context) + + async def _run_benchmark(self): + # avoid: aiohttp.internal WARNING The object should be created from async function + es = self.create_es_clients(sync=False) + try: + cancel = threading.Event() + # allow to buffer more events than by default as we expect to have way more clients. + self.sampler = driver.Sampler(start_timestamp=time.perf_counter(), buffer_size=65536) + + for task in self.challenge.schedule: + self.current_tasks = [] + aws = [] + for sub_task in task: + self.current_tasks.append(sub_task) + self.logger.info("Running task [%s] with [%d] clients...", sub_task.name, sub_task.clients) + for client_id in range(sub_task.clients): + schedule = driver.schedule_for(self.track, sub_task, client_id) + # used to indicate that we want to prematurely consider this completed. This is *not* due to + # cancellation but a regular event in a benchmark and used to model task dependency of parallel tasks. + complete = threading.Event() + e = driver.AsyncExecutor(client_id, sub_task, schedule, es, self.sampler, cancel, complete, self.abort_on_error) + aws.append(e()) + # join point + _ = await asyncio.gather(*aws) + self.logger.info("All clients have finished running task [%s]", task.name) + # drain the active samples before we move on to the next task + self.update_samples() + self.post_process_samples() + self.reset_relative_time() + self.update_progress_message(task_finished=True) + finally: + await asyncio.get_event_loop().shutdown_asyncgens() + for e in es.values(): + await e.transport.close() + + def reset_relative_time(self): + self.logger.debug("Resetting relative time of request metrics store.") + self.metrics_store.reset_relative_time() + + def update_samples(self): + if self.sampler: + samples = self.sampler.samples + self.logger.info("Adding [%d] new samples.", len(samples)) + if len(samples) > 0: + self.raw_samples += samples + # We need to check all samples, they will be from different clients + for s in samples: + self.most_recent_sample_per_client[s.client_id] = s + self.logger.info("Done adding [%d] new samples.", len(samples)) + else: + self.logger.info("No sampler defined yet. Skipping update of samples.") + + def update_progress_message(self, task_finished=False): + if not self.quiet and len(self.current_tasks) > 0: + tasks = ",".join([t.name for t in self.current_tasks]) + + if task_finished: + total_progress = 1.0 + else: + # we only count clients which actually contribute to progress. If clients are executing tasks eternally in a parallel + # structure, we should not count them. The reason is that progress depends entirely on the client(s) that execute the + # task that is completing the parallel structure. + progress_per_client = [s.percent_completed + for s in self.most_recent_sample_per_client.values() if s.percent_completed is not None] + + num_clients = max(len(progress_per_client), 1) + total_progress = sum(progress_per_client) / num_clients + self.progress_reporter.print("Running %s" % tasks, "[%3d%% done]" % (round(total_progress * 100))) + if task_finished: + self.progress_reporter.finish() + + def post_process_samples(self): + if len(self.raw_samples) == 0: + return + total_start = time.perf_counter() + start = total_start + # we do *not* do this here to avoid concurrent updates (we are single-threaded) but rather to make it clear that we use + # only a snapshot and that new data will go to a new sample set. + raw_samples = self.raw_samples + self.raw_samples = [] + for sample in raw_samples: + meta_data = self.merge( + self.track.meta_data, + self.challenge.meta_data, + sample.operation.meta_data, + sample.task.meta_data, + sample.request_meta_data) + + self.metrics_store.put_value_cluster_level(name="latency", value=sample.latency_ms, unit="ms", task=sample.task.name, + operation=sample.operation.name, operation_type=sample.operation.type, + sample_type=sample.sample_type, absolute_time=sample.absolute_time, + relative_time=sample.relative_time, meta_data=meta_data) + + self.metrics_store.put_value_cluster_level(name="service_time", value=sample.service_time_ms, unit="ms", task=sample.task.name, + operation=sample.task.name, operation_type=sample.operation.type, + sample_type=sample.sample_type, absolute_time=sample.absolute_time, + relative_time=sample.relative_time, meta_data=meta_data) + + self.metrics_store.put_value_cluster_level(name="processing_time", value=sample.processing_time_ms, + unit="ms", task=sample.task.name, + operation=sample.task.name, operation_type=sample.operation.type, + sample_type=sample.sample_type, absolute_time=sample.absolute_time, + relative_time=sample.relative_time, meta_data=meta_data) + + end = time.perf_counter() + self.logger.debug("Storing latency and service time took [%f] seconds.", (end - start)) + start = end + aggregates = self.throughput_calculator.calculate(raw_samples) + end = time.perf_counter() + self.logger.debug("Calculating throughput took [%f] seconds.", (end - start)) + start = end + for task, samples in aggregates.items(): + meta_data = self.merge( + self.track.meta_data, + self.challenge.meta_data, + task.operation.meta_data, + task.meta_data + ) + for absolute_time, relative_time, sample_type, throughput, throughput_unit in samples: + self.metrics_store.put_value_cluster_level(name="throughput", value=throughput, unit=throughput_unit, task=task.name, + operation=task.operation.name, operation_type=task.operation.type, + sample_type=sample_type, absolute_time=absolute_time, + relative_time=relative_time, meta_data=meta_data) + end = time.perf_counter() + self.logger.debug("Storing throughput took [%f] seconds.", (end - start)) + start = end + # this will be a noop for the in-memory metrics store. If we use an ES metrics store however, this will ensure that we already send + # the data and also clear the in-memory buffer. This allows users to see data already while running the benchmark. In cases where + # it does not matter (i.e. in-memory) we will still defer this step until the end. + # + # Don't force refresh here in the interest of short processing times. We don't need to query immediately afterwards so there is + # no need for frequent refreshes. + self.metrics_store.flush(refresh=False) + end = time.perf_counter() + self.logger.debug("Flushing the metrics store took [%f] seconds.", (end - start)) + self.logger.debug("Postprocessing [%d] raw samples took [%f] seconds in total.", len(raw_samples), (end - total_start)) + + def merge(self, *args): + result = {} + for arg in args: + if arg is not None: + result.update(arg) + return result diff --git a/esrally/driver/driver.py b/esrally/driver/driver.py index 25b610f35..b847a5e9f 100644 --- a/esrally/driver/driver.py +++ b/esrally/driver/driver.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import asyncio import concurrent.futures import datetime import logging @@ -601,6 +602,12 @@ def post_process_samples(self): sample_type=sample.sample_type, absolute_time=sample.absolute_time, relative_time=sample.relative_time, meta_data=meta_data) + self.metrics_store.put_value_cluster_level(name="processing_time", value=sample.processing_time_ms, + unit="ms", task=sample.task.name, + operation=sample.task.name, operation_type=sample.operation.type, + sample_type=sample.sample_type, absolute_time=sample.absolute_time, + relative_time=sample.relative_time, meta_data=meta_data) + end = time.perf_counter() self.logger.debug("Storing latency and service time took [%f] seconds.", (end - start)) start = end @@ -772,7 +779,6 @@ def receiveUnrecognizedMessage(self, msg, sender): self.logger.info("LoadGenerator[%d] received unknown message [%s] (ignoring).", self.client_id, str(msg)) def drive(self): - profiling_enabled = self.config.opts("driver", "profiling") task_allocation = self.current_task_and_advance() # skip non-tasks in the task list while task_allocation is None: @@ -799,7 +805,7 @@ def drive(self): self.client_id, task) else: self.logger.info("LoadGenerator[%d] is executing [%s].", self.client_id, task) - self.sampler = Sampler(self.client_id, task, start_timestamp=time.perf_counter()) + self.sampler = Sampler(start_timestamp=time.perf_counter()) # We cannot use the global client index here because we need to support parallel execution of tasks with multiple clients. # # Consider the following scenario: @@ -810,11 +816,10 @@ def drive(self): # Now we need to ensure that we start partitioning parameters correctly in both cases. And that means we need to start # from (client) index 0 in both cases instead of 0 for indexA and 4 for indexB. schedule = schedule_for(self.track, task_allocation.task, task_allocation.client_index_in_task) + executor = AsyncIoAdapter( + self.config, self.client_id, task, schedule, self.sampler, self.cancel, self.complete, self.abort_on_error) - executor = Executor(task, schedule, self.es, self.sampler, self.cancel, self.complete, self.abort_on_error) - final_executor = Profiler(executor, self.client_id, task) if profiling_enabled else executor - - self.executor_future = self.pool.submit(final_executor) + self.executor_future = self.pool.submit(executor) self.wakeupAfter(datetime.timedelta(seconds=self.wakeup_interval)) else: raise exceptions.RallyAssertionError("Unknown task type [%s]" % type(task_allocation)) @@ -841,20 +846,19 @@ class Sampler: Encapsulates management of gathered samples. """ - def __init__(self, client_id, task, start_timestamp): - self.client_id = client_id - self.task = task + def __init__(self, start_timestamp, buffer_size=16384): self.start_timestamp = start_timestamp - self.q = queue.Queue(maxsize=16384) + self.q = queue.Queue(maxsize=buffer_size) self.logger = logging.getLogger(__name__) - def add(self, sample_type, request_meta_data, latency_ms, service_time_ms, total_ops, total_ops_unit, time_period, percent_completed): + def add(self, task, client_id, sample_type, meta_data, latency, service_time, processing_time, ops, ops_unit, + time_period, percent_completed): try: - self.q.put_nowait(Sample(self.client_id, time.time(), time.perf_counter() - self.start_timestamp, self.task, - sample_type, request_meta_data, latency_ms, service_time_ms, total_ops, total_ops_unit, time_period, - percent_completed)) + self.q.put_nowait(Sample(client_id, time.time(), time.perf_counter() - self.start_timestamp, task, + sample_type, meta_data, latency, service_time, processing_time, ops, + ops_unit, time_period, percent_completed)) except queue.Full: - self.logger.warning("Dropping sample for [%s] due to a full sampling queue.", self.task.operation.name) + self.logger.warning("Dropping sample for [%s] due to a full sampling queue.", task.operation.name) @property def samples(self): @@ -868,8 +872,8 @@ def samples(self): class Sample: - def __init__(self, client_id, absolute_time, relative_time, task, sample_type, request_meta_data, latency_ms, service_time_ms, - total_ops, total_ops_unit, time_period, percent_completed): + def __init__(self, client_id, absolute_time, relative_time, task, sample_type, request_meta_data, latency_ms, + service_time_ms, processing_time_ms, total_ops, total_ops_unit, time_period, percent_completed): self.client_id = client_id self.absolute_time = absolute_time self.relative_time = relative_time @@ -878,6 +882,7 @@ def __init__(self, client_id, absolute_time, relative_time, task, sample_type, r self.request_meta_data = request_meta_data self.latency_ms = latency_ms self.service_time_ms = service_time_ms + self.processing_time_ms = processing_time_ms self.total_ops = total_ops self.total_ops_unit = total_ops_unit self.time_period = time_period @@ -1020,41 +1025,90 @@ def calculate(self, samples, bucket_interval_secs=1): return global_throughput -class Profiler: - def __init__(self, target, client_id, task): +class AsyncIoAdapter: + def __init__(self, cfg, client_id, sub_task, schedule, sampler, cancel, complete, abort_on_error): + self.cfg = cfg + self.client_id = client_id + self.sub_task = sub_task + self.schedule = schedule + self.sampler = sampler + self.cancel = cancel + self.complete = complete + self.abort_on_error = abort_on_error + self.profiling_enabled = self.cfg.opts("driver", "profiling") + self.debug_event_loop = self.cfg.opts("system", "async.debug", mandatory=False, default_value=False) + + def __call__(self, *args, **kwargs): + # only possible in Python 3.7+ (has introduced get_running_loop) + # try: + # loop = asyncio.get_running_loop() + # except RuntimeError: + # loop = asyncio.new_event_loop() + # asyncio.set_event_loop(loop) + loop = asyncio.new_event_loop() + loop.set_debug(self.debug_event_loop) + loop.set_exception_handler(self._logging_exception_handler) + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(self.run()) + finally: + loop.close() + + def _logging_exception_handler(self, loop, context): + logging.getLogger(__name__).error("Uncaught exception in event loop: %s", context) + + async def run(self): + def es_clients(all_hosts, all_client_options): + es = {} + for cluster_name, cluster_hosts in all_hosts.items(): + es[cluster_name] = client.EsClientFactory(cluster_hosts, all_client_options[cluster_name]).create_async() + return es + + es = es_clients(self.cfg.opts("client", "hosts").all_hosts, self.cfg.opts("client", "options").all_client_options) + async_executor = AsyncExecutor( + self.client_id, self.sub_task, self.schedule, es, self.sampler, self.cancel, self.complete, self.abort_on_error) + final_executor = AsyncProfiler(async_executor) if self.profiling_enabled else async_executor + try: + return await final_executor() + finally: + await asyncio.get_event_loop().shutdown_asyncgens() + for e in es.values(): + await e.transport.close() + + +class AsyncProfiler: + def __init__(self, target): """ :param target: The actual executor which should be profiled. - :param client_id: The id of the client that executes the operation. - :param task: The task that is executed. """ self.target = target - self.client_id = client_id - self.task = task self.profile_logger = logging.getLogger("rally.profile") - def __call__(self, *args, **kwargs): - import cProfile - import pstats + async def __call__(self, *args, **kwargs): + import yappi import io as python_io - profiler = cProfile.Profile() - profiler.enable() + yappi.start() try: - return self.target(*args, **kwargs) + return await self.target(*args, **kwargs) finally: - profiler.disable() + yappi.stop() s = python_io.StringIO() - sortby = 'cumulative' - ps = pstats.Stats(profiler, stream=s).sort_stats(sortby) - ps.print_stats() - - profile = "\n=== Profile START for client [%s] and task [%s] ===\n" % (str(self.client_id), str(self.task)) + yappi.get_func_stats().print_all(out=s, columns={ + 0: ("name", 140), + 1: ("ncall", 8), + 2: ("tsub", 8), + 3: ("ttot", 8), + 4: ("tavg", 8) + }) + + profile = "\n=== Profile START ===\n" profile += s.getvalue() - profile += "=== Profile END for client [%s] and task [%s] ===" % (str(self.client_id), str(self.task)) + profile += "=== Profile END ===" self.profile_logger.info(profile) -class Executor: - def __init__(self, task, schedule, es, sampler, cancel, complete, abort_on_error=False): +class AsyncExecutor: + def __init__(self, client_id, task, schedule, es, sampler, cancel, complete, abort_on_error=False): """ Executes tasks according to the schedule for a given operation. @@ -1065,6 +1119,7 @@ def __init__(self, task, schedule, es, sampler, cancel, complete, abort_on_error :param cancel: A shared boolean that indicates we need to cancel execution. :param complete: A shared boolean that indicates we need to prematurely complete execution. """ + self.client_id = client_id self.task = task self.op = task.operation self.schedule_handle = schedule @@ -1075,13 +1130,15 @@ def __init__(self, task, schedule, es, sampler, cancel, complete, abort_on_error self.abort_on_error = abort_on_error self.logger = logging.getLogger(__name__) - def __call__(self, *args, **kwargs): + async def __call__(self, *args, **kwargs): total_start = time.perf_counter() # lazily initialize the schedule + self.logger.debug("Initializing schedule for client id [%s].", self.client_id) schedule = self.schedule_handle() + self.logger.debug("Entering main loop for client id [%s].", self.client_id) # noinspection PyBroadException try: - for expected_scheduled_time, sample_type, percent_completed, runner, params in schedule: + async for expected_scheduled_time, sample_type, percent_completed, runner, params in schedule: if self.cancel.is_set(): self.logger.info("User cancelled execution.") break @@ -1090,12 +1147,14 @@ def __call__(self, *args, **kwargs): if throughput_throttled: rest = absolute_expected_schedule_time - time.perf_counter() if rest > 0: - time.sleep(rest) - start = time.perf_counter() - total_ops, total_ops_unit, request_meta_data = execute_single(runner, self.es, params, self.abort_on_error) - stop = time.perf_counter() - - service_time = stop - start + await asyncio.sleep(rest) + request_context = self.es["default"].init_request_context() + processing_start = time.perf_counter() + total_ops, total_ops_unit, request_meta_data = await execute_single(runner, self.es, params, self.abort_on_error) + processing_end = time.perf_counter() + stop = request_context["request_end"] + service_time = request_context["request_end"] - request_context["request_start"] + processing_time = processing_end - processing_start # Do not calculate latency separately when we don't throttle throughput. This metric is just confusing then. latency = stop - absolute_expected_schedule_time if throughput_throttled else service_time # last sample should bump progress to 100% if externally completed. @@ -1106,8 +1165,10 @@ def __call__(self, *args, **kwargs): progress = runner.percent_completed else: progress = percent_completed - self.sampler.add(sample_type, request_meta_data, convert.seconds_to_ms(latency), convert.seconds_to_ms(service_time), - total_ops, total_ops_unit, (stop - total_start), progress) + self.sampler.add(self.task, self.client_id, sample_type, request_meta_data, + convert.seconds_to_ms(latency), convert.seconds_to_ms(service_time), + convert.seconds_to_ms(processing_time), total_ops, total_ops_unit, + (stop - total_start), progress) if completed: self.logger.info("Task is considered completed due to external event.") @@ -1121,7 +1182,7 @@ def __call__(self, *args, **kwargs): self.complete.set() -def execute_single(runner, es, params, abort_on_error=False): +async def execute_single(runner, es, params, abort_on_error=False): """ Invokes the given runner once and provides the runner's return value in a uniform structure. @@ -1129,8 +1190,8 @@ def execute_single(runner, es, params, abort_on_error=False): """ import elasticsearch try: - with runner: - return_value = runner(es, params) + async with runner: + return_value = await runner(es, params) if isinstance(return_value, tuple) and len(return_value) == 2: total_ops, total_ops_unit = return_value request_meta_data = {"success": True} @@ -1154,7 +1215,10 @@ def execute_single(runner, es, params, abort_on_error=False): # The ES client will sometimes return string like "N/A" or "TIMEOUT" for connection errors. if isinstance(e.status_code, int): request_meta_data["http-status"] = e.status_code - if e.info: + # connection timeout errors don't provide a helpful description + if isinstance(e, elasticsearch.ConnectionTimeout): + request_meta_data["error-description"] = "network connection timed out" + elif e.info: request_meta_data["error-description"] = "%s (%s)" % (e.error, e.info) else: request_meta_data["error-description"] = e.error @@ -1169,7 +1233,6 @@ def execute_single(runner, es, params, abort_on_error=False): if description: msg += ", Description: %s" % description raise exceptions.RallyAssertionError(msg) - return total_ops, total_ops_unit, request_meta_data @@ -1404,10 +1467,14 @@ def __init__(self, task_name, sched, task_progress_control, runner, params): self.runner = runner self.params = params self.logger = logging.getLogger(__name__) + # TODO: Can we offload the parameter source execution to a different thread / process? Is this too heavy-weight? + #from concurrent.futures import ThreadPoolExecutor + #import asyncio + #self.io_pool_exc = ThreadPoolExecutor(max_workers=1) + #self.loop = asyncio.get_event_loop() - def __call__(self): + async def __call__(self): next_scheduled = 0 - if self.task_progress_control.infinite: self.logger.info("Parameter source will determine when the schedule for [%s] terminates.", self.task_name) param_source_knows_progress = hasattr(self.params, "percent_completed") @@ -1416,6 +1483,7 @@ def __call__(self): try: # does not contribute at all to completion. Hence, we cannot define completion. percent_completed = self.params.percent_completed if param_source_knows_progress else None + #current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params) yield (next_scheduled, self.task_progress_control.sample_type, percent_completed, self.runner, self.params.params()) next_scheduled = self.sched.next(next_scheduled) @@ -1430,6 +1498,7 @@ def __call__(self): str(self.task_progress_control), self.task_name) while not self.task_progress_control.completed: try: + #current_params = await self.loop.run_in_executor(self.io_pool_exc, self.params.params) yield (next_scheduled, self.task_progress_control.sample_type, self.task_progress_control.percent_completed, diff --git a/esrally/driver/runner.py b/esrally/driver/runner.py index a36fca427..6f70e4c58 100644 --- a/esrally/driver/runner.py +++ b/esrally/driver/runner.py @@ -15,14 +15,17 @@ # specific language governing permissions and limitations # under the License. +import asyncio +import json import logging import random import sys -import time import types from collections import Counter, OrderedDict from copy import deepcopy +import ijson + from esrally import exceptions, track # Mapping from operation type to specific runner @@ -30,37 +33,37 @@ def register_default_runners(): - register_runner(track.OperationType.Bulk.name, BulkIndex()) - register_runner(track.OperationType.ForceMerge.name, ForceMerge()) - register_runner(track.OperationType.IndicesStats.name, Retry(IndicesStats())) - register_runner(track.OperationType.NodesStats.name, NodeStats()) - register_runner(track.OperationType.Search.name, Query()) - register_runner(track.OperationType.RawRequest.name, RawRequest()) + register_runner(track.OperationType.Bulk.name, BulkIndex(), async_runner=True) + register_runner(track.OperationType.ForceMerge.name, ForceMerge(), async_runner=True) + register_runner(track.OperationType.IndicesStats.name, Retry(IndicesStats()), async_runner=True) + register_runner(track.OperationType.NodesStats.name, NodeStats(), async_runner=True) + register_runner(track.OperationType.Search.name, Query(), async_runner=True) + register_runner(track.OperationType.RawRequest.name, RawRequest(), async_runner=True) # This is an administrative operation but there is no need for a retry here as we don't issue a request - register_runner(track.OperationType.Sleep.name, Sleep()) + register_runner(track.OperationType.Sleep.name, Sleep(), async_runner=True) # these requests should not be retried as they are not idempotent - register_runner(track.OperationType.RestoreSnapshot.name, RestoreSnapshot()) + register_runner(track.OperationType.RestoreSnapshot.name, RestoreSnapshot(), async_runner=True) # We treat the following as administrative commands and thus already start to wrap them in a retry. - register_runner(track.OperationType.ClusterHealth.name, Retry(ClusterHealth())) - register_runner(track.OperationType.PutPipeline.name, Retry(PutPipeline())) - register_runner(track.OperationType.Refresh.name, Retry(Refresh())) - register_runner(track.OperationType.CreateIndex.name, Retry(CreateIndex())) - register_runner(track.OperationType.DeleteIndex.name, Retry(DeleteIndex())) - register_runner(track.OperationType.CreateIndexTemplate.name, Retry(CreateIndexTemplate())) - register_runner(track.OperationType.DeleteIndexTemplate.name, Retry(DeleteIndexTemplate())) - register_runner(track.OperationType.ShrinkIndex.name, Retry(ShrinkIndex())) - register_runner(track.OperationType.CreateMlDatafeed.name, Retry(CreateMlDatafeed())) - register_runner(track.OperationType.DeleteMlDatafeed.name, Retry(DeleteMlDatafeed())) - register_runner(track.OperationType.StartMlDatafeed.name, Retry(StartMlDatafeed())) - register_runner(track.OperationType.StopMlDatafeed.name, Retry(StopMlDatafeed())) - register_runner(track.OperationType.CreateMlJob.name, Retry(CreateMlJob())) - register_runner(track.OperationType.DeleteMlJob.name, Retry(DeleteMlJob())) - register_runner(track.OperationType.OpenMlJob.name, Retry(OpenMlJob())) - register_runner(track.OperationType.CloseMlJob.name, Retry(CloseMlJob())) - register_runner(track.OperationType.DeleteSnapshotRepository.name, Retry(DeleteSnapshotRepository())) - register_runner(track.OperationType.CreateSnapshotRepository.name, Retry(CreateSnapshotRepository())) - register_runner(track.OperationType.WaitForRecovery.name, Retry(IndicesRecovery())) - register_runner(track.OperationType.PutSettings.name, Retry(PutSettings())) + register_runner(track.OperationType.ClusterHealth.name, Retry(ClusterHealth()), async_runner=True) + register_runner(track.OperationType.PutPipeline.name, Retry(PutPipeline()), async_runner=True) + register_runner(track.OperationType.Refresh.name, Retry(Refresh()), async_runner=True) + register_runner(track.OperationType.CreateIndex.name, Retry(CreateIndex()), async_runner=True) + register_runner(track.OperationType.DeleteIndex.name, Retry(DeleteIndex()), async_runner=True) + register_runner(track.OperationType.CreateIndexTemplate.name, Retry(CreateIndexTemplate()), async_runner=True) + register_runner(track.OperationType.DeleteIndexTemplate.name, Retry(DeleteIndexTemplate()), async_runner=True) + register_runner(track.OperationType.ShrinkIndex.name, Retry(ShrinkIndex()), async_runner=True) + register_runner(track.OperationType.CreateMlDatafeed.name, Retry(CreateMlDatafeed()), async_runner=True) + register_runner(track.OperationType.DeleteMlDatafeed.name, Retry(DeleteMlDatafeed()), async_runner=True) + register_runner(track.OperationType.StartMlDatafeed.name, Retry(StartMlDatafeed()), async_runner=True) + register_runner(track.OperationType.StopMlDatafeed.name, Retry(StopMlDatafeed()), async_runner=True) + register_runner(track.OperationType.CreateMlJob.name, Retry(CreateMlJob()), async_runner=True) + register_runner(track.OperationType.DeleteMlJob.name, Retry(DeleteMlJob()), async_runner=True) + register_runner(track.OperationType.OpenMlJob.name, Retry(OpenMlJob()), async_runner=True) + register_runner(track.OperationType.CloseMlJob.name, Retry(CloseMlJob()), async_runner=True) + register_runner(track.OperationType.DeleteSnapshotRepository.name, Retry(DeleteSnapshotRepository()), async_runner=True) + register_runner(track.OperationType.CreateSnapshotRepository.name, Retry(CreateSnapshotRepository()), async_runner=True) + register_runner(track.OperationType.WaitForRecovery.name, Retry(IndicesRecovery()), async_runner=True) + register_runner(track.OperationType.PutSettings.name, Retry(PutSettings()), async_runner=True) def runner_for(operation_type): @@ -70,10 +73,15 @@ def runner_for(operation_type): raise exceptions.RallyError("No runner available for operation type [%s]" % operation_type) -def register_runner(operation_type, runner): +def register_runner(operation_type, runner, **kwargs): logger = logging.getLogger(__name__) + async_runner = kwargs.get("async_runner", False) + if not async_runner: + raise exceptions.RallyAssertionError( + "Runner [{}] must be implemented as async runner and registered with async_runner=True.".format(str(runner))) + if getattr(runner, "multi_cluster", False) == True: - if "__enter__" in dir(runner) and "__exit__" in dir(runner): + if "__aenter__" in dir(runner) and "__aexit__" in dir(runner): if logger.isEnabledFor(logging.DEBUG): logger.debug("Registering runner object [%s] for [%s].", str(runner), str(operation_type)) __RUNNERS[operation_type] = _multi_cluster_runner(runner, str(runner), context_manager_enabled=True) @@ -86,7 +94,7 @@ def register_runner(operation_type, runner): if logger.isEnabledFor(logging.DEBUG): logger.debug("Registering runner function [%s] for [%s].", str(runner), str(operation_type)) __RUNNERS[operation_type] = _single_cluster_runner(runner, runner.__name__) - elif "__enter__" in dir(runner) and "__exit__" in dir(runner): + elif "__aenter__" in dir(runner) and "__aexit__" in dir(runner): if logger.isEnabledFor(logging.DEBUG): logger.debug("Registering context-manager capable runner object [%s] for [%s].", str(runner), str(operation_type)) __RUNNERS[operation_type] = _single_cluster_runner(runner, str(runner), context_manager_enabled=True) @@ -110,10 +118,10 @@ def __init__(self, *args, **kwargs): super(Runner, self).__init__(*args, **kwargs) self.logger = logging.getLogger(__name__) - def __enter__(self): + async def __aenter__(self): return self - def __call__(self, *args): + async def __call__(self, *args): """ Runs the actual method that should be benchmarked. @@ -125,7 +133,7 @@ def __call__(self, *args): """ raise NotImplementedError("abstract operation") - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb): return False @@ -184,18 +192,18 @@ def completed(self): def percent_completed(self): return None - def __call__(self, *args): - return self.delegate(*args) + async def __call__(self, *args): + return await self.delegate(*args) def __repr__(self, *args, **kwargs): return repr(self.delegate) - def __enter__(self): - self.delegate.__enter__() + async def __aenter__(self): + await self.delegate.__aenter__() return self - def __exit__(self, exc_type, exc_val, exc_tb): - return self.delegate.__exit__(exc_type, exc_val, exc_tb) + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.delegate.__aexit__(exc_type, exc_val, exc_tb) class WithCompletion(Runner, Delegator): @@ -211,18 +219,18 @@ def completed(self): def percent_completed(self): return self.progressable.percent_completed - def __call__(self, *args): - return self.delegate(*args) + async def __call__(self, *args): + return await self.delegate(*args) def __repr__(self, *args, **kwargs): return repr(self.delegate) - def __enter__(self): - self.delegate.__enter__() + async def __aenter__(self): + await self.delegate.__aenter__() return self - def __exit__(self, exc_type, exc_val, exc_tb): - return self.delegate.__exit__(exc_type, exc_val, exc_tb) + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.delegate.__aexit__(exc_type, exc_val, exc_tb) class MultiClientRunner(Runner, Delegator): @@ -232,8 +240,8 @@ def __init__(self, runnable, name, client_extractor, context_manager_enabled=Fal self.client_extractor = client_extractor self.context_manager_enabled = context_manager_enabled - def __call__(self, *args): - return self.delegate(self.client_extractor(args[0]), *args[1:]) + async def __call__(self, *args): + return await self.delegate(self.client_extractor(args[0]), *args[1:]) def __repr__(self, *args, **kwargs): if self.context_manager_enabled: @@ -241,14 +249,14 @@ def __repr__(self, *args, **kwargs): else: return "user-defined runner for [%s]" % self.name - def __enter__(self): + async def __aenter__(self): if self.context_manager_enabled: - self.delegate.__enter__() + await self.delegate.__aenter__() return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb): if self.context_manager_enabled: - return self.delegate.__exit__(exc_type, exc_val, exc_tb) + return await self.delegate.__aexit__(exc_type, exc_val, exc_tb) else: return False @@ -269,7 +277,7 @@ class BulkIndex(Runner): def __init__(self): super().__init__() - def __call__(self, es, params): + async def __call__(self, es, params): """ Runs one bulk indexing operation. @@ -442,11 +450,16 @@ def __call__(self, es, params): with_action_metadata = mandatory(params, "action-metadata-present", self) bulk_size = mandatory(params, "bulk-size", self) + # parse responses lazily in the standard case - responses might be large thus parsing skews results and if no + # errors have occurred we only need a small amount of information from the potentially large response. + if not detailed_results: + es.return_raw_response() + if with_action_metadata: # only half of the lines are documents - response = es.bulk(body=params["body"], params=bulk_params) + response = await es.bulk(body=params["body"], params=bulk_params) else: - response = es.bulk(body=params["body"], index=index, doc_type=params.get("type"), params=bulk_params) + response = await es.bulk(body=params["body"], index=index, doc_type=params.get("type"), params=bulk_params) stats = self.detailed_stats(params, bulk_size, response) if detailed_results else self.simple_stats(bulk_size, response) @@ -529,20 +542,23 @@ def detailed_stats(self, params, bulk_size, response): def simple_stats(self, bulk_size, response): bulk_error_count = 0 error_details = set() - if response["errors"]: - for idx, item in enumerate(response["items"]): + # parse lazily on the fast path + props = parse(response, ["errors", "took"]) + + if props.get("errors", False): + # Reparse fully in case of errors - this will be slower + parsed_response = json.loads(response.getvalue()) + for idx, item in enumerate(parsed_response["items"]): data = next(iter(item.values())) if data["status"] > 299 or ('_shards' in data and data["_shards"]["failed"] > 0): bulk_error_count += 1 self.extract_error_details(error_details, data) stats = { - "took": response.get("took"), + "took": props.get("took"), "success": bulk_error_count == 0, "success-count": bulk_size - bulk_error_count, "error-count": bulk_error_count } - if "ingest_took" in response: - stats["ingest_took"] = response["ingest_took"] if bulk_error_count > 0: stats["error-type"] = "bulk" stats["error-description"] = self.error_description(error_details) @@ -574,7 +590,7 @@ class ForceMerge(Runner): Runs a force merge operation against Elasticsearch. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch max_num_segments = params.get("max-num-segments") # preliminary support for overriding the global request timeout (see #567). As force-merge falls back to @@ -583,18 +599,18 @@ def __call__(self, es, params): request_timeout = params.get("request-timeout") try: if max_num_segments: - es.indices.forcemerge(index=params.get("index"), max_num_segments=max_num_segments, request_timeout=request_timeout) + await es.indices.forcemerge(index=params.get("index"), max_num_segments=max_num_segments, request_timeout=request_timeout) else: - es.indices.forcemerge(index=params.get("index"), request_timeout=request_timeout) + await es.indices.forcemerge(index=params.get("index"), request_timeout=request_timeout) except elasticsearch.TransportError as e: # this is caused by older versions of Elasticsearch (< 2.1), fall back to optimize if e.status_code == 400: params = {"request_timeout": request_timeout} if max_num_segments: - es.transport.perform_request("POST", "/_optimize?max_num_segments={}".format(max_num_segments), - params=params) + await es.transport.perform_request("POST", "/_optimize?max_num_segments={}".format(max_num_segments), + params=params) else: - es.transport.perform_request("POST", "/_optimize", params=params) + await es.transport.perform_request("POST", "/_optimize", params=params) else: raise e @@ -618,11 +634,11 @@ def _get(self, v, path): def _safe_string(self, v): return str(v) if v is not None else None - def __call__(self, es, params): + async def __call__(self, es, params): index = params.get("index", "_all") condition = params.get("condition") - response = es.indices.stats(index=index, metric="_all") + response = await es.indices.stats(index=index, metric="_all") if condition: path = mandatory(condition, "path", repr(self)) expected_value = mandatory(condition, "expected-value", repr(self)) @@ -655,20 +671,50 @@ class NodeStats(Runner): Gather node stats for all nodes. """ - def __call__(self, es, params): - es.nodes.stats(metric="_all") + async def __call__(self, es, params): + await es.nodes.stats(metric="_all") def __repr__(self, *args, **kwargs): return "node-stats" -def search_type_fallback(es, doc_type, index, body, params): - if doc_type and not index: - index = "_all" - path = "/%s/%s/_search" % (index, doc_type) - return es.transport.perform_request( - "GET", path, params=params, body=body - ) +def parse(text, props, lists=None): + """ + Selectively parsed the provided text as JSON extracting only the properties provided in ``props``. If ``lists`` is + specified, this function determines whether the provided lists are empty (respective value will be ``True``) or + contain elements (respective key will be ``False``). + + :param text: A text to parse. + :param props: A mandatory list of property paths (separated by a dot character) for which to extract values. + :param lists: An optional list of property paths to JSON lists in the provided text. + :return: A dict containing all properties and lists that have been found in the provided text. + """ + text.seek(0) + parser = ijson.parse(text) + parsed = {} + parsed_lists = {} + current_list = None + expect_end_array = False + try: + for prefix, event, value in parser: + if expect_end_array: + # True if the list is empty, False otherwise + parsed_lists[current_list] = event == "end_array" + expect_end_array = False + if prefix in props: + parsed[prefix] = value + elif lists is not None and prefix in lists and event == "start_array": + current_list = prefix + expect_end_array = True + # found all necessary properties + if len(parsed) == len(props) and (lists is None or len(parsed_lists) == len(lists)): + break + except ijson.IncompleteJSONError: + # did not find all properties + pass + + parsed.update(parsed_lists) + return parsed class Query(Runner): @@ -682,6 +728,13 @@ class Query(Runner): * `cache`: True iff the request cache should be used. * `body`: Query body + The following parameters are optional: + + * `detailed-results` (default: ``False``): Records more detailed meta-data about queries. As it analyzes the + corresponding response in more detail, this might incur additional + overhead which can skew measurement results. This flag is ineffective + for scroll queries (detailed meta-data are always returned). + If the following parameters are present in addition, a scroll query will be issued: * `pages`: Number of pages to retrieve at most for this scroll. If a scroll query does yield less results than the specified number of @@ -707,92 +760,128 @@ class Query(Runner): def __init__(self): super().__init__() - self.scroll_id = None - self.es = None - def __call__(self, es, params): + async def __call__(self, es, params): if "pages" in params and "results-per-page" in params: - return self.scroll_query(es, params) + return await self.scroll_query(es, params) else: - return self.request_body_query(es, params) + return await self.request_body_query(es, params) - def request_body_query(self, es, params): + async def request_body_query(self, es, params): request_params = self._default_request_params(params) index = params.get("index", "_all") body = mandatory(params, "body", self) doc_type = params.get("type") + detailed_results = params.get("detailed-results", False) params = request_params + # disable eager response parsing - responses might be huge thus skewing results + es.return_raw_response() + if doc_type is not None: - r = search_type_fallback(es, doc_type, index, body, params) + r = await self._search_type_fallback(es, doc_type, index, body, params) else: - r = es.search(index=index, body=body, params=params) - hits = r["hits"]["total"] - if isinstance(hits, dict): - hits_total = hits["value"] - hits_relation = hits["relation"] + r = await es.search(index=index, body=body, params=params) + + if detailed_results: + props = parse(r, ["hits.total", "hits.total.value", "hits.total.relation", "timed_out", "took"]) + hits_total = props.get("hits.total.value", props.get("hits.total", 0)) + hits_relation = props.get("hits.total.relation", "eq") + timed_out = props.get("timed_out", False) + took = props.get("took", 0) + + return { + "weight": 1, + "unit": "ops", + "success": True, + "hits": hits_total, + "hits_relation": hits_relation, + "timed_out": timed_out, + "took": took + } else: - hits_total = hits - hits_relation = "eq" - return { - "weight": 1, - "unit": "ops", - "hits": hits_total, - "hits_relation": hits_relation, - "timed_out": r["timed_out"], - "took": r["took"] - } + return { + "weight": 1, + "unit": "ops", + "success": True + } - def scroll_query(self, es, params): + async def scroll_query(self, es, params): request_params = self._default_request_params(params) hits = 0 + hits_relation = None retrieved_pages = 0 timed_out = False took = 0 - self.es = es # explicitly convert to int to provoke an error otherwise total_pages = sys.maxsize if params["pages"] == "all" else int(params["pages"]) size = params.get("results-per-page") + scroll_id = None - for page in range(total_pages): - if page == 0: - index = params.get("index", "_all") - body = mandatory(params, "body", self) - sort = "_doc" - scroll = "10s" - doc_type = params.get("type") - params = request_params - if doc_type is not None: - params["sort"] = sort - params["scroll"] = scroll - params["size"] = size - r = search_type_fallback(es, doc_type, index, body, params) + # disable eager response parsing - responses might be huge thus skewing results + es.return_raw_response() + + try: + for page in range(total_pages): + if page == 0: + index = params.get("index", "_all") + body = mandatory(params, "body", self) + sort = "_doc" + scroll = "10s" + doc_type = params.get("type") + params = request_params + if doc_type is not None: + params["sort"] = sort + params["scroll"] = scroll + params["size"] = size + r = await self._search_type_fallback(es, doc_type, index, body, params) + else: + r = await es.search(index=index, body=body, params=params, sort=sort, scroll=scroll, size=size) + + props = parse(r, + ["_scroll_id", "hits.total", "hits.total.value", "hits.total.relation", "timed_out", "took"], + ["hits.hits"]) + scroll_id = props.get("_scroll_id") + hits = props.get("hits.total.value", props.get("hits.total", 0)) + hits_relation = props.get("hits.total.relation", "eq") + timed_out = props.get("timed_out", False) + took = props.get("took", 0) + all_results_collected = (size is not None and hits < size) or hits == 0 else: - r = es.search(index=index, body=body, params=params, sort=sort, scroll=scroll, size=size) - # This should only happen if we concurrently create an index and start searching - self.scroll_id = r.get("_scroll_id", None) - else: - r = es.scroll(body={"scroll_id": self.scroll_id, "scroll": "10s"}) - hit_count = len(r["hits"]["hits"]) - timed_out = timed_out or r["timed_out"] - took += r["took"] - hits += hit_count - retrieved_pages += 1 - if hit_count == 0: - # We're done prematurely. Even if we are on page index zero, we still made one call. - break + r = await es.scroll(body={"scroll_id": scroll_id, "scroll": "10s"}) + props = parse(r, ["hits.total", "hits.total.value", "hits.total.relation", "timed_out", "took"], ["hits.hits"]) + timed_out = timed_out or props.get("timed_out", False) + took += props.get("took", 0) + # is the list of hits empty? + all_results_collected = props.get("hits.hits", False) + retrieved_pages += 1 + if all_results_collected: + break + finally: + if scroll_id: + # noinspection PyBroadException + try: + await es.clear_scroll(body={"scroll_id": [scroll_id]}) + except BaseException: + self.logger.exception("Could not clear scroll [%s]. This will lead to excessive resource usage in " + "Elasticsearch and will skew your benchmark results.", scroll_id) return { "weight": retrieved_pages, "pages": retrieved_pages, "hits": hits, - # as Rally determines the number of hits in a scroll, the result is always accurate. - "hits_relation": "eq", + "hits_relation": hits_relation, "unit": "pages", "timed_out": timed_out, "took": took } + async def _search_type_fallback(self, es, doc_type, index, body, params): + if doc_type and not index: + index = "_all" + path = "/%s/%s/_search" % (index, doc_type) + return await es.transport.perform_request("GET", path, params=params, body=body) + def _default_request_params(self, params): request_params = params.get("request-params", {}) cache = params.get("cache") @@ -800,17 +889,6 @@ def _default_request_params(self, params): request_params["request_cache"] = str(cache).lower() return request_params - def __exit__(self, exc_type, exc_val, exc_tb): - if self.scroll_id and self.es: - try: - self.es.clear_scroll(body={"scroll_id": [self.scroll_id]}) - except BaseException: - self.logger.exception("Could not clear scroll [%s]. This will lead to excessive resource usage in " - "Elasticsearch and will skew your benchmark results.", self.scroll_id) - self.scroll_id = None - self.es = None - return False - def __repr__(self, *args, **kwargs): return "query" @@ -820,7 +898,7 @@ class ClusterHealth(Runner): Get cluster health """ - def __call__(self, es, params): + async def __call__(self, es, params): from enum import Enum from functools import total_ordering @@ -855,7 +933,7 @@ def status(v): # either the user has defined something or we're good with any count of relocating shards. expected_relocating_shards = int(request_params.get("wait_for_relocating_shards", sys.maxsize)) - result = es.cluster.health(index=index, params=request_params) + result = await es.cluster.health(index=index, params=request_params) cluster_status = result["status"] relocating_shards = result["relocating_shards"] @@ -877,12 +955,12 @@ class PutPipeline(Runner): API is only available from Elasticsearch 5.0 onwards. """ - def __call__(self, es, params): - es.ingest.put_pipeline(id=mandatory(params, "id", self), - body=mandatory(params, "body", self), - master_timeout=params.get("master-timeout"), - timeout=params.get("timeout"), - ) + async def __call__(self, es, params): + await es.ingest.put_pipeline(id=mandatory(params, "id", self), + body=mandatory(params, "body", self), + master_timeout=params.get("master-timeout"), + timeout=params.get("timeout"), + ) def __repr__(self, *args, **kwargs): return "put-pipeline" @@ -893,8 +971,8 @@ class Refresh(Runner): Execute the `refresh API `_. """ - def __call__(self, es, params): - es.indices.refresh(index=params.get("index", "_all")) + async def __call__(self, es, params): + await es.indices.refresh(index=params.get("index", "_all")) def __repr__(self, *args, **kwargs): return "refresh" @@ -905,11 +983,11 @@ class CreateIndex(Runner): Execute the `create index API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): indices = mandatory(params, "indices", self) request_params = params.get("request-params", {}) for index, body in indices: - es.indices.create(index=index, body=body, params=request_params) + await es.indices.create(index=index, body=body, params=request_params) return len(indices), "ops" def __repr__(self, *args, **kwargs): @@ -921,7 +999,7 @@ class DeleteIndex(Runner): Execute the `delete index API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): ops = 0 indices = mandatory(params, "indices", self) @@ -930,11 +1008,11 @@ def __call__(self, es, params): for index_name in indices: if not only_if_exists: - es.indices.delete(index=index_name, params=request_params) + await es.indices.delete(index=index_name, params=request_params) ops += 1 - elif only_if_exists and es.indices.exists(index=index_name): + elif only_if_exists and await es.indices.exists(index=index_name): self.logger.info("Index [%s] already exists. Deleting it.", index_name) - es.indices.delete(index=index_name, params=request_params) + await es.indices.delete(index=index_name, params=request_params) ops += 1 return ops, "ops" @@ -948,13 +1026,13 @@ class CreateIndexTemplate(Runner): Execute the `PUT index template API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): templates = mandatory(params, "templates", self) request_params = params.get("request-params", {}) for template, body in templates: - es.indices.put_template(name=template, - body=body, - params=request_params) + await es.indices.put_template(name=template, + body=body, + params=request_params) return len(templates), "ops" def __repr__(self, *args, **kwargs): @@ -967,7 +1045,7 @@ class DeleteIndexTemplate(Runner): `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): template_names = mandatory(params, "templates", self) only_if_exists = params.get("only-if-exists", False) request_params = params.get("request-params", {}) @@ -975,15 +1053,15 @@ def __call__(self, es, params): for template_name, delete_matching_indices, index_pattern in template_names: if not only_if_exists: - es.indices.delete_template(name=template_name, params=request_params) + await es.indices.delete_template(name=template_name, params=request_params) ops_count += 1 - elif only_if_exists and es.indices.exists_template(template_name): + elif only_if_exists and await es.indices.exists_template(template_name): self.logger.info("Index template [%s] already exists. Deleting it.", template_name) - es.indices.delete_template(name=template_name, params=request_params) + await es.indices.delete_template(name=template_name, params=request_params) ops_count += 1 # ensure that we do not provide an empty index pattern by accident if delete_matching_indices and index_pattern: - es.indices.delete(index=index_pattern) + await es.indices.delete(index=index_pattern) ops_count += 1 return ops_count, "ops" @@ -1003,10 +1081,10 @@ def __init__(self): super().__init__() self.cluster_health = Retry(ClusterHealth()) - def _wait_for(self, es, idx, description): + async def _wait_for(self, es, idx, description): # wait a little bit before the first check - time.sleep(3) - result = self.cluster_health(es, params={ + await asyncio.sleep(3) + result = await self.cluster_health(es, params={ "index": idx, "retries": sys.maxsize, "request-params": { @@ -1016,7 +1094,7 @@ def _wait_for(self, es, idx, description): if not result["success"]: raise exceptions.RallyAssertionError("Failed to wait for [{}].".format(description)) - def __call__(self, es, params): + async def __call__(self, es, params): source_index = mandatory(params, "source-index", self) target_index = mandatory(params, "target-index", self) # we need to inject additional settings so we better copy the body @@ -1026,7 +1104,8 @@ def __call__(self, es, params): if not shrink_node: node_names = [] # choose a random data node - for node in es.nodes.info()["nodes"].values(): + node_info = await es.nodes.info() + for node in node_info["nodes"].values(): if "data" in node["roles"]: node_names.append(node["name"]) if not node_names: @@ -1035,27 +1114,27 @@ def __call__(self, es, params): self.logger.info("Using [%s] as shrink node.", shrink_node) self.logger.info("Preparing [%s] for shrinking.", source_index) # prepare index for shrinking - es.indices.put_settings(index=source_index, - body={ - "settings": { - "index.routing.allocation.require._name": shrink_node, - "index.blocks.write": "true" - } - }, - preserve_existing=True) + await es.indices.put_settings(index=source_index, + body={ + "settings": { + "index.routing.allocation.require._name": shrink_node, + "index.blocks.write": "true" + } + }, + preserve_existing=True) self.logger.info("Waiting for relocation to finish for index [%s]...", source_index) - self._wait_for(es, source_index, "shard relocation for index [{}]".format(source_index)) + await self._wait_for(es, source_index, "shard relocation for index [{}]".format(source_index)) self.logger.info("Shrinking [%s] to [%s].", source_index, target_index) if "settings" not in target_body: target_body["settings"] = {} target_body["settings"]["index.routing.allocation.require._name"] = None target_body["settings"]["index.blocks.write"] = None # kick off the shrink operation - es.indices.shrink(index=source_index, target=target_index, body=target_body) + await es.indices.shrink(index=source_index, target=target_index, body=target_body) self.logger.info("Waiting for shrink to finish for index [%s]...", source_index) - self._wait_for(es, target_index, "shrink for index [{}]".format(target_index)) + await self._wait_for(es, target_index, "shrink for index [{}]".format(target_index)) self.logger.info("Shrinking [%s] to [%s] has finished.", source_index, target_index) # ops_count is not really important for this operation... return 1, "ops" @@ -1069,16 +1148,16 @@ class CreateMlDatafeed(Runner): Execute the `create datafeed API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch datafeed_id = mandatory(params, "datafeed-id", self) body = mandatory(params, "body", self) try: - es.xpack.ml.put_datafeed(datafeed_id=datafeed_id, body=body) + await es.xpack.ml.put_datafeed(datafeed_id=datafeed_id, body=body) except elasticsearch.TransportError as e: # fallback to old path if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "PUT", "/_xpack/ml/datafeeds/%s" % datafeed_id, params=params, @@ -1096,19 +1175,19 @@ class DeleteMlDatafeed(Runner): Execute the `delete datafeed API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch datafeed_id = mandatory(params, "datafeed-id", self) force = params.get("force", False) try: # we don't want to fail if a datafeed does not exist, thus we ignore 404s. - es.xpack.ml.delete_datafeed(datafeed_id=datafeed_id, force=force, ignore=[404]) + await es.xpack.ml.delete_datafeed(datafeed_id=datafeed_id, force=force, ignore=[404]) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "DELETE", - "/_xpack/ml/datafeeds/%s" %datafeed_id, + "/_xpack/ml/datafeeds/%s" % datafeed_id, params=params, ) else: @@ -1123,7 +1202,7 @@ class StartMlDatafeed(Runner): Execute the `start datafeed API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch datafeed_id = mandatory(params, "datafeed-id", self) body = params.get("body") @@ -1131,11 +1210,11 @@ def __call__(self, es, params): end = params.get("end") timeout = params.get("timeout") try: - es.xpack.ml.start_datafeed(datafeed_id=datafeed_id, body=body, start=start, end=end, timeout=timeout) + await es.xpack.ml.start_datafeed(datafeed_id=datafeed_id, body=body, start=start, end=end, timeout=timeout) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "POST", "/_xpack/ml/datafeeds/%s/_start" % datafeed_id, params=params, @@ -1153,17 +1232,17 @@ class StopMlDatafeed(Runner): Execute the `stop datafeed API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch datafeed_id = mandatory(params, "datafeed-id", self) force = params.get("force", False) timeout = params.get("timeout") try: - es.xpack.ml.stop_datafeed(datafeed_id=datafeed_id, force=force, timeout=timeout) + await es.xpack.ml.stop_datafeed(datafeed_id=datafeed_id, force=force, timeout=timeout) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "POST", "/_xpack/ml/datafeeds/%s/_stop" % datafeed_id, params=params @@ -1180,16 +1259,16 @@ class CreateMlJob(Runner): Execute the `create job API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch job_id = mandatory(params, "job-id", self) body = mandatory(params, "body", self) try: - es.xpack.ml.put_job(job_id=job_id, body=body) + await es.xpack.ml.put_job(job_id=job_id, body=body) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "PUT", "/_xpack/ml/anomaly_detectors/%s" % job_id, params=params, @@ -1207,13 +1286,13 @@ class DeleteMlJob(Runner): Execute the `delete job API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch job_id = mandatory(params, "job-id", self) force = params.get("force", False) # we don't want to fail if a job does not exist, thus we ignore 404s. try: - es.xpack.ml.delete_job(job_id=job_id, force=force, ignore=[404]) + await es.xpack.ml.delete_job(job_id=job_id, force=force, ignore=[404]) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: @@ -1234,15 +1313,15 @@ class OpenMlJob(Runner): Execute the `open job API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch job_id = mandatory(params, "job-id", self) try: - es.xpack.ml.open_job(job_id=job_id) + await es.xpack.ml.open_job(job_id=job_id) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "POST", "/_xpack/ml/anomaly_detectors/%s/_open" % job_id, params=params, @@ -1259,17 +1338,17 @@ class CloseMlJob(Runner): Execute the `close job API `_. """ - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch job_id = mandatory(params, "job-id", self) force = params.get("force", False) timeout = params.get("timeout") try: - es.xpack.ml.close_job(job_id=job_id, force=force, timeout=timeout) + await es.xpack.ml.close_job(job_id=job_id, force=force, timeout=timeout) except elasticsearch.TransportError as e: # fallback to old path (ES < 7) if e.status_code == 400: - es.transport.perform_request( + await es.transport.perform_request( "POST", "/_xpack/ml/anomaly_detectors/%s/_close" % job_id, params=params, @@ -1282,17 +1361,17 @@ def __repr__(self, *args, **kwargs): class RawRequest(Runner): - def __call__(self, es, params): + async def __call__(self, es, params): request_params = {} if "ignore" in params: request_params["ignore"] = params["ignore"] request_params.update(params.get("request-params", {})) - es.transport.perform_request(method=params.get("method", "GET"), - url=mandatory(params, "path", self), - headers=params.get("headers"), - body=params.get("body"), - params=request_params) + await es.transport.perform_request(method=params.get("method", "GET"), + url=mandatory(params, "path", self), + headers=params.get("headers"), + body=params.get("body"), + params=request_params) def __repr__(self, *args, **kwargs): return "raw-request" @@ -1303,8 +1382,8 @@ class Sleep(Runner): Sleeps for the specified duration not issuing any request. """ - def __call__(self, es, params): - time.sleep(mandatory(params, "duration", "sleep")) + async def __call__(self, es, params): + await asyncio.sleep(mandatory(params, "duration", "sleep")) def __repr__(self, *args, **kwargs): return "sleep" @@ -1314,8 +1393,8 @@ class DeleteSnapshotRepository(Runner): """ Deletes a snapshot repository """ - def __call__(self, es, params): - es.snapshot.delete_repository(repository=mandatory(params, "repository", repr(self))) + async def __call__(self, es, params): + await es.snapshot.delete_repository(repository=mandatory(params, "repository", repr(self))) def __repr__(self, *args, **kwargs): return "delete-snapshot-repository" @@ -1325,12 +1404,11 @@ class CreateSnapshotRepository(Runner): """ Creates a new snapshot repository """ - def __call__(self, es, params): + async def __call__(self, es, params): request_params = params.get("request-params", {}) - es.snapshot.create_repository( - repository=mandatory(params, "repository", repr(self)), - body=mandatory(params, "body", repr(self)), - params=request_params) + await es.snapshot.create_repository(repository=mandatory(params, "repository", repr(self)), + body=mandatory(params, "body", repr(self)), + params=request_params) def __repr__(self, *args, **kwargs): return "create-snapshot-repository" @@ -1340,13 +1418,13 @@ class RestoreSnapshot(Runner): """ Restores a snapshot from an already registered repository """ - def __call__(self, es, params): + async def __call__(self, es, params): request_params = params.get("request-params", {}) - es.snapshot.restore(repository=mandatory(params, "repository", repr(self)), - snapshot=mandatory(params, "snapshot", repr(self)), - body=params.get("body"), - wait_for_completion=params.get("wait-for-completion", False), - params=request_params) + await es.snapshot.restore(repository=mandatory(params, "repository", repr(self)), + snapshot=mandatory(params, "snapshot", repr(self)), + body=params.get("body"), + wait_for_completion=params.get("wait-for-completion", False), + params=request_params) def __repr__(self, *args, **kwargs): return "restore-snapshot" @@ -1367,17 +1445,17 @@ def completed(self): def percent_completed(self): return self._percent_completed - def __call__(self, es, params): + async def __call__(self, es, params): remaining_attempts = params.get("completion-recheck-attempts", 3) wait_period = params.get("completion-recheck-wait-period", 2) response = None while not response and remaining_attempts > 0: - response = es.indices.recovery(active_only=True) + response = await es.indices.recovery(active_only=True) remaining_attempts -= 1 # This might also happen if all recoveries have just finished and we happen to call the API # before the next recovery is scheduled. if not response: - time.sleep(wait_period) + await asyncio.sleep(wait_period) if not response: self._completed = True @@ -1415,8 +1493,8 @@ class PutSettings(Runner): Updates cluster settings with the `cluster settings API _. """ - def __call__(self, es, params): - es.cluster.put_settings(body=mandatory(params, "body", repr(self))) + async def __call__(self, es, params): + await es.cluster.put_settings(body=mandatory(params, "body", repr(self))) def __repr__(self, *args, **kwargs): return "put-settings" @@ -1442,11 +1520,11 @@ class Retry(Runner, Delegator): def __init__(self, delegate): super().__init__(delegate=delegate) - def __enter__(self): - self.delegate.__enter__() + async def __aenter__(self): + await self.delegate.__aenter__() return self - def __call__(self, es, params): + async def __call__(self, es, params): import elasticsearch import socket @@ -1463,7 +1541,7 @@ def __call__(self, es, params): for attempt in range(max_attempts): last_attempt = attempt + 1 == max_attempts try: - return_value = self.delegate(es, params) + return_value = await self.delegate(es, params) if last_attempt or not retry_on_error: return return_value # we can determine success if and only if the runner returns a dict. Otherwise, we have to assume it was fine. @@ -1473,25 +1551,25 @@ def __call__(self, es, params): return return_value else: self.logger.debug("%s has returned with an error: %s.", repr(self.delegate), return_value) - time.sleep(sleep_time) + await asyncio.sleep(sleep_time) else: return return_value except (socket.timeout, elasticsearch.exceptions.ConnectionError): if last_attempt or not retry_on_timeout: raise else: - time.sleep(sleep_time) + await asyncio.sleep(sleep_time) except elasticsearch.exceptions.TransportError as e: if last_attempt or not retry_on_timeout: raise e elif e.status_code == 408: self.logger.debug("%s has timed out.", repr(self.delegate)) - time.sleep(sleep_time) + await asyncio.sleep(sleep_time) else: raise e - def __exit__(self, exc_type, exc_val, exc_tb): - return self.delegate.__exit__(exc_type, exc_val, exc_tb) + async def __aexit__(self, exc_type, exc_val, exc_tb): + return await self.delegate.__aexit__(exc_type, exc_val, exc_tb) def __repr__(self, *args, **kwargs): return "retryable %s" % repr(self.delegate) diff --git a/esrally/metrics.py b/esrally/metrics.py index a4b6c8723..a011d134c 100644 --- a/esrally/metrics.py +++ b/esrally/metrics.py @@ -690,16 +690,17 @@ def _add(self, doc): """ raise NotImplementedError("abstract method") - def get_one(self, name, sample_type=None, node_name=None): + def get_one(self, name, sample_type=None, node_name=None, task=None): """ Gets one value for the given metric name (even if there should be more than one). :param name: The metric name to query. :param sample_type The sample type to query. Optional. By default, all samples are considered. :param node_name The name of the node where this metric was gathered. Optional. + :param task The task name to query. Optional. :return: The corresponding value for the given metric name or None if there is no value. """ - return self._first_or_none(self.get(name=name, sample_type=sample_type, node_name=node_name)) + return self._first_or_none(self.get(name=name, task=task, sample_type=sample_type, node_name=node_name)) @staticmethod def _first_or_none(values): @@ -1641,6 +1642,7 @@ def __call__(self): self.summary_stats("throughput", t), self.single_latency(t), self.single_latency(t, metric_name="service_time"), + self.single_latency(t, metric_name="processing_time"), self.error_rate(t), self.merge( self.track.meta_data, @@ -1843,6 +1845,8 @@ def op_metrics(op_item, key, single_value=False): all_results.append(op_metrics(item, "latency")) if "service_time" in item: all_results.append(op_metrics(item, "service_time")) + if "processing_time" in item: + all_results.append(op_metrics(item, "processing_time")) if "error_rate" in item: all_results.append(op_metrics(item, "error_rate", single_value=True)) elif metric == "ml_processing_time": @@ -1874,13 +1878,14 @@ def op_metrics(op_item, key, single_value=False): def v(self, d, k, default=None): return d.get(k, default) if d else default - def add_op_metrics(self, task, operation, throughput, latency, service_time, error_rate, meta): + def add_op_metrics(self, task, operation, throughput, latency, service_time, processing_time, error_rate, meta): doc = { "task": task, "operation": operation, "throughput": throughput, "latency": latency, "service_time": service_time, + "processing_time": processing_time, "error_rate": error_rate, } if meta: diff --git a/esrally/racecontrol.py b/esrally/racecontrol.py index 15e99eb80..2ec319a10 100644 --- a/esrally/racecontrol.py +++ b/esrally/racecontrol.py @@ -97,19 +97,15 @@ class BenchmarkActor(actor.RallyActor): def __init__(self): super().__init__() self.cfg = None - self.race = None - self.metrics_store = None - self.race_store = None - self.cancelled = False - self.error = False self.start_sender = None self.mechanic = None self.main_driver = None - self.track_revision = None + self.coordinator = None def receiveMsg_PoisonMessage(self, msg, sender): self.logger.info("BenchmarkActor got notified of poison message [%s] (forwarding).", (str(msg))) - self.error = True + if self.coordinator: + self.coordinator.error = True self.send(self.start_sender, msg) def receiveUnrecognizedMessage(self, msg, sender): @@ -117,43 +113,46 @@ def receiveUnrecognizedMessage(self, msg, sender): @actor.no_retry("race control") def receiveMsg_Setup(self, msg, sender): - self.setup(msg, sender) + self.start_sender = sender + self.cfg = msg.cfg + self.coordinator = BenchmarkCoordinator(msg.cfg) + self.coordinator.setup(sources=msg.sources) + self.logger.info("Asking mechanic to start the engine.") + cluster_settings = self.coordinator.current_challenge.cluster_settings + self.mechanic = self.createActor(mechanic.MechanicActor, targetActorRequirements={"coordinator": True}) + self.send(self.mechanic, mechanic.StartEngine(self.cfg, + self.coordinator.metrics_store.open_context, + cluster_settings, + msg.sources, + msg.build, + msg.distribution, + msg.external, + msg.docker)) @actor.no_retry("race control") def receiveMsg_EngineStarted(self, msg, sender): self.logger.info("Mechanic has started engine successfully.") - self.race.team_revision = msg.team_revision + self.coordinator.race.team_revision = msg.team_revision self.main_driver = self.createActor(driver.DriverActor, targetActorRequirements={"coordinator": True}) self.logger.info("Telling driver to prepare for benchmarking.") - self.send(self.main_driver, driver.PrepareBenchmark(self.cfg, self.race.track)) + self.send(self.main_driver, driver.PrepareBenchmark(self.cfg, self.coordinator.current_track)) @actor.no_retry("race control") def receiveMsg_PreparationComplete(self, msg, sender): - self.race.distribution_flavor = msg.distribution_flavor - self.race.distribution_version = msg.distribution_version - self.race.revision = msg.revision - # store race initially (without any results) so other components can retrieve full metadata - self.race_store.store_race(self.race) - if self.race.challenge.auto_generated: - console.info("Racing on track [{}] and car {} with version [{}].\n" - .format(self.race.track_name, self.race.car, self.race.distribution_version)) - else: - console.info("Racing on track [{}], challenge [{}] and car {} with version [{}].\n" - .format(self.race.track_name, self.race.challenge_name, self.race.car, self.race.distribution_version)) - self.run() + self.coordinator.on_preparation_complete(msg.distribution_flavor, msg.distribution_version, msg.revision) + self.logger.info("Telling driver to start benchmark.") + self.send(self.main_driver, driver.StartBenchmark()) @actor.no_retry("race control") def receiveMsg_TaskFinished(self, msg, sender): - self.logger.info("Task has finished.") - self.logger.info("Bulk adding request metrics to metrics store.") - self.metrics_store.bulk_add(msg.metrics) + self.coordinator.on_task_finished(msg.metrics) # We choose *NOT* to reset our own metrics store's timer as this one is only used to collect complete metrics records from # other stores (used by driver and mechanic). Hence there is no need to reset the timer in our own metrics store. self.send(self.mechanic, mechanic.ResetRelativeTime(msg.next_task_scheduled_in)) @actor.no_retry("race control") def receiveMsg_BenchmarkCancelled(self, msg, sender): - self.cancelled = True + self.coordinator.cancelled = True # even notify the start sender if it is the originator. The reason is that we call #ask() which waits for a reply. # We also need to ask in order to avoid races between this notification and the following ActorExitRequest. self.send(self.start_sender, msg) @@ -161,53 +160,56 @@ def receiveMsg_BenchmarkCancelled(self, msg, sender): @actor.no_retry("race control") def receiveMsg_BenchmarkFailure(self, msg, sender): self.logger.info("Received a benchmark failure from [%s] and will forward it now.", sender) - self.error = True + self.coordinator.error = True self.send(self.start_sender, msg) @actor.no_retry("race control") def receiveMsg_BenchmarkComplete(self, msg, sender): - self.logger.info("Benchmark is complete.") - self.logger.info("Bulk adding request metrics to metrics store.") - self.metrics_store.bulk_add(msg.metrics) - self.metrics_store.flush() - if not self.cancelled and not self.error: - final_results = metrics.calculate_results(self.metrics_store, self.race) - self.race.add_results(final_results) - self.race_store.store_race(self.race) - metrics.results_store(self.cfg).store_results(self.race) - reporter.summarize(final_results, self.cfg) - else: - self.logger.info("Suppressing output of summary report. Cancelled = [%r], Error = [%r].", self.cancelled, self.error) - self.metrics_store.close() - - self.teardown() + self.coordinator.on_benchmark_complete(msg.metrics) + self.send(self.main_driver, thespian.actors.ActorExitRequest()) + self.main_driver = None + self.logger.info("Asking mechanic to stop the engine.") + self.send(self.mechanic, mechanic.StopEngine()) @actor.no_retry("race control") def receiveMsg_EngineStopped(self, msg, sender): self.logger.info("Mechanic has stopped engine successfully.") self.send(self.start_sender, Success()) - def setup(self, msg, sender): - self.start_sender = sender - self.cfg = msg.cfg - # to load the track we need to know the correct cluster distribution version. Usually, this value should be set but there are rare - # cases (external pipeline and user did not specify the distribution version) where we need to derive it ourselves. For source - # builds we always assume "master" - if not msg.sources and not self.cfg.exists("mechanic", "distribution.version"): + +class BenchmarkCoordinator: + def __init__(self, cfg): + self.logger = logging.getLogger(__name__) + self.cfg = cfg + self.race = None + self.metrics_store = None + self.race_store = None + self.cancelled = False + self.error = False + self.track_revision = None + self.current_track = None + self.current_challenge = None + + def setup(self, sources=False): + # to load the track we need to know the correct cluster distribution version. Usually, this value should be set + # but there are rare cases (external pipeline and user did not specify the distribution version) where we need + # to derive it ourselves. For source builds we always assume "master" + if not sources and not self.cfg.exists("mechanic", "distribution.version"): distribution_version = mechanic.cluster_distribution_version(self.cfg) self.logger.info("Automatically derived distribution version [%s]", distribution_version) self.cfg.add(config.Scope.benchmark, "mechanic", "distribution.version", distribution_version) - t = track.load_track(self.cfg) + self.current_track = track.load_track(self.cfg) self.track_revision = self.cfg.opts("track", "repository.revision", mandatory=False) challenge_name = self.cfg.opts("track", "challenge.name") - challenge = t.find_challenge_or_default(challenge_name) - if challenge is None: - raise exceptions.SystemSetupError("Track [%s] does not provide challenge [%s]. List the available tracks with %s list tracks." - % (t.name, challenge_name, PROGRAM_NAME)) - if challenge.user_info: - console.info(challenge.user_info) - self.race = metrics.create_race(self.cfg, t, challenge, self.track_revision) + self.current_challenge = self.current_track.find_challenge_or_default(challenge_name) + if self.current_challenge is None: + raise exceptions.SystemSetupError( + "Track [{}] does not provide challenge [{}]. List the available tracks with {} list tracks.".format( + self.current_track.name, challenge_name, PROGRAM_NAME)) + if self.current_challenge.user_info: + console.info(self.current_challenge.user_info) + self.race = metrics.create_race(self.cfg, self.current_track, self.current_challenge, self.track_revision) self.metrics_store = metrics.metrics_store( self.cfg, @@ -216,21 +218,39 @@ def setup(self, msg, sender): read_only=False ) self.race_store = metrics.race_store(self.cfg) - self.logger.info("Asking mechanic to start the engine.") - cluster_settings = challenge.cluster_settings - self.mechanic = self.createActor(mechanic.MechanicActor, targetActorRequirements={"coordinator": True}) - self.send(self.mechanic, mechanic.StartEngine(self.cfg, self.metrics_store.open_context, cluster_settings, msg.sources, msg.build, - msg.distribution, msg.external, msg.docker)) - def run(self): - self.logger.info("Telling driver to start benchmark.") - self.send(self.main_driver, driver.StartBenchmark()) + def on_preparation_complete(self, distribution_flavor, distribution_version, revision): + self.race.distribution_flavor = distribution_flavor + self.race.distribution_version = distribution_version + self.race.revision = revision + # store race initially (without any results) so other components can retrieve full metadata + self.race_store.store_race(self.race) + if self.race.challenge.auto_generated: + console.info("Racing on track [{}] and car {} with version [{}].\n" + .format(self.race.track_name, self.race.car, self.race.distribution_version)) + else: + console.info("Racing on track [{}], challenge [{}] and car {} with version [{}].\n" + .format(self.race.track_name, self.race.challenge_name, self.race.car, self.race.distribution_version)) - def teardown(self): - self.send(self.main_driver, thespian.actors.ActorExitRequest()) - self.main_driver = None - self.logger.info("Asking mechanic to stop the engine.") - self.send(self.mechanic, mechanic.StopEngine()) + def on_task_finished(self, new_metrics): + self.logger.info("Task has finished.") + self.logger.info("Bulk adding request metrics to metrics store.") + self.metrics_store.bulk_add(new_metrics) + + def on_benchmark_complete(self, new_metrics): + self.logger.info("Benchmark is complete.") + self.logger.info("Bulk adding request metrics to metrics store.") + self.metrics_store.bulk_add(new_metrics) + self.metrics_store.flush() + if not self.cancelled and not self.error: + final_results = metrics.calculate_results(self.metrics_store, self.race) + self.race.add_results(final_results) + self.race_store.store_race(self.race) + metrics.results_store(self.cfg).store_results(self.race) + reporter.summarize(final_results, self.cfg) + else: + self.logger.info("Suppressing output of summary report. Cancelled = [%r], Error = [%r].", self.cancelled, self.error) + self.metrics_store.close() def race(cfg, sources=False, build=False, distribution=False, external=False, docker=False): @@ -368,3 +388,25 @@ def run(cfg): except BaseException: tb = sys.exc_info()[2] raise exceptions.RallyError("This race ended with a fatal crash.").with_traceback(tb) + + +def run_async(cfg): + console.warn("The race-async command is experimental.") + logger = logging.getLogger(__name__) + # We'll use a special car name for external benchmarks. + cfg.add(config.Scope.benchmark, "mechanic", "car.names", ["external"]) + coordinator = BenchmarkCoordinator(cfg) + + try: + coordinator.setup() + race_driver = driver.AsyncDriver(cfg, coordinator.current_track, coordinator.current_challenge) + distribution_flavor, distribution_version, revision = race_driver.setup() + coordinator.on_preparation_complete(distribution_flavor, distribution_version, revision) + + new_metrics = race_driver.run() + coordinator.on_benchmark_complete(new_metrics) + except KeyboardInterrupt: + logger.info("User has cancelled the benchmark.") + except BaseException as e: + tb = sys.exc_info()[2] + raise exceptions.RallyError(str(e)).with_traceback(tb) diff --git a/esrally/rally.py b/esrally/rally.py index 432c4b3d4..fe499014c 100644 --- a/esrally/rally.py +++ b/esrally/rally.py @@ -67,6 +67,7 @@ def runtime_jdk(v): help="") race_parser = subparsers.add_parser("race", help="Run the benchmarking pipeline. This sub-command should typically be used.") + async_race_parser = subparsers.add_parser("race-async") # change in favor of "list telemetry", "list tracks", "list pipelines" list_parser = subparsers.add_parser("list", help="List configuration options") list_parser.add_argument( @@ -345,7 +346,7 @@ def runtime_jdk(v): default=preserve_install, action="store_true") - for p in [parser, list_parser, race_parser, generate_parser]: + for p in [parser, list_parser, race_parser, async_race_parser, generate_parser]: p.add_argument( "--distribution-version", help="Define the version of the Elasticsearch distribution to download. " @@ -385,7 +386,7 @@ def runtime_jdk(v): default=False, action="store_true") - for p in [parser, race_parser]: + for p in [parser, race_parser, async_race_parser]: p.add_argument( "--race-id", help="Define a unique id for this race.", @@ -517,7 +518,7 @@ def runtime_jdk(v): # The options below are undocumented and can be removed or changed at any time. # ############################################################################### - for p in [parser, race_parser]: + for p in [parser, race_parser, async_race_parser]: # This option is intended to tell Rally to assume a different start date than 'now'. This is effectively just useful for things like # backtesting or a benchmark run across environments (think: comparison of EC2 and bare metal) but never for the typical user. p.add_argument( @@ -539,7 +540,7 @@ def runtime_jdk(v): default=False) for p in [parser, config_parser, list_parser, race_parser, compare_parser, download_parser, install_parser, - start_parser, stop_parser, info_parser, generate_parser]: + start_parser, stop_parser, info_parser, generate_parser, async_race_parser]: # This option is needed to support a separate configuration for the integration tests on the same machine p.add_argument( "--configuration-name", @@ -708,6 +709,8 @@ def dispatch_sub_command(cfg, sub_command): mechanic.stop(cfg) elif sub_command == "race": race(cfg) + elif sub_command == "race-async": + racecontrol.run_async(cfg) elif sub_command == "generate": generate(cfg) elif sub_command == "info": diff --git a/esrally/reporter.py b/esrally/reporter.py index 1b0a339bc..78ce9752c 100644 --- a/esrally/reporter.py +++ b/esrally/reporter.py @@ -99,6 +99,8 @@ def __init__(self, results, config): reporting_values = config.opts("reporting", "values") self.report_all_values = reporting_values == "all" self.report_all_percentile_values = reporting_values == "all-percentiles" + self.show_processing_time = convert.to_bool(config.opts("reporting", "output.processingtime", + mandatory=False, default_value=False)) self.cwd = config.opts("node", "rally.cwd") def report(self): @@ -122,6 +124,9 @@ def report(self): metrics_table.extend(self.report_throughput(record, task)) metrics_table.extend(self.report_latency(record, task)) metrics_table.extend(self.report_service_time(record, task)) + # this is mostly needed for debugging purposes but not so relevant to end users + if self.show_processing_time: + metrics_table.extend(self.report_processing_time(record, task)) metrics_table.extend(self.report_error_rate(record, task)) self.add_warnings(warnings, record, task) @@ -162,6 +167,9 @@ def report_latency(self, values, task): def report_service_time(self, values, task): return self.report_percentiles("service time", task, values["service_time"]) + def report_processing_time(self, values, task): + return self.report_percentiles("processing time", task, values["processing_time"]) + def report_percentiles(self, name, task, value): lines = [] if value: diff --git a/esrally/track/loader.py b/esrally/track/loader.py index c746f71ca..5817c3bd7 100644 --- a/esrally/track/loader.py +++ b/esrally/track/loader.py @@ -939,8 +939,8 @@ def load(self): def register_param_source(self, name, param_source): params.register_param_source_for_name(name, param_source) - def register_runner(self, name, runner): - self.runner_registry(name, runner) + def register_runner(self, name, runner, **kwargs): + self.runner_registry(name, runner, **kwargs) def register_scheduler(self, name, scheduler): self.scheduler_registry(name, scheduler) @@ -950,7 +950,8 @@ def meta_data(self): from esrally import version return { - "rally_version": version.release_version() + "rally_version": version.release_version(), + "async_runner": True } diff --git a/esrally/track/params.py b/esrally/track/params.py index 08618630e..23540f66e 100644 --- a/esrally/track/params.py +++ b/esrally/track/params.py @@ -644,7 +644,7 @@ def chain(*iterables): def create_default_reader(docs, offset, num_lines, num_docs, batch_size, bulk_size, id_conflicts, conflict_probability, on_conflict, recency): - source = Slice(io.FileSource, offset, num_lines) + source = Slice(io.MmapSource, offset, num_lines) if docs.includes_action_and_meta_data: return SourceOnlyIndexDataReader(docs.document_file, batch_size, bulk_size, source, docs.target_index, docs.target_type) @@ -905,7 +905,7 @@ def __next__(self): if docs_in_bulk == 0: break docs_in_batch += docs_in_bulk - batch.append((docs_in_bulk, "".join(bulk))) + batch.append((docs_in_bulk, b"".join(bulk))) if docs_in_batch == 0: raise StopIteration() return self.index_name, self.type_name, batch @@ -938,7 +938,7 @@ def _read_bulk_fast(self): """ current_bulk = [] # hoist - action_metadata_line = self.action_metadata_line + action_metadata_line = self.action_metadata_line.encode("utf-8") docs = next(self.file_source) for doc in docs: @@ -957,11 +957,11 @@ def _read_bulk_regular(self): action_metadata_item = next(self.action_metadata) if action_metadata_item: action_type, action_metadata_line = action_metadata_item - current_bulk.append(action_metadata_line) + current_bulk.append(action_metadata_line.encode("utf-8")) if action_type == "update": # remove the trailing "\n" as the doc needs to fit on one line doc = doc.strip() - current_bulk.append("{\"doc\":%s}\n" % doc) + current_bulk.append(b"{\"doc\":%s}\n" % doc) else: current_bulk.append(doc) else: diff --git a/esrally/utils/io.py b/esrally/utils/io.py index ea134126d..62a13c213 100644 --- a/esrally/utils/io.py +++ b/esrally/utils/io.py @@ -22,6 +22,9 @@ import subprocess import tarfile import zipfile +from contextlib import suppress + +import mmap from esrally.utils import console @@ -76,6 +79,64 @@ def __str__(self, *args, **kwargs): return self.file_name +class MmapSource: + """ + MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. + """ + def __init__(self, file_name, mode, encoding="utf-8"): + self.file_name = file_name + self.mode = mode + self.encoding = encoding + self.f = None + self.mm = None + + def open(self): + self.f = open(self.file_name, mode="r+b") + self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) + # madvise is available in Python 3.8+ + with suppress(AttributeError): + self.mm.madvise(mmap.MADV_SEQUENTIAL) + + # allow for chaining + return self + + def seek(self, offset): + self.mm.seek(offset) + + def read(self): + return self.mm.read() + + def readline(self): + return self.mm.readline() + + def readlines(self, num_lines): + lines = [] + mm = self.mm + for _ in range(num_lines): + line = mm.readline() + if line == b"": + break + lines.append(line) + return lines + + def close(self): + self.mm.close() + self.mm = None + self.f.close() + self.f = None + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def __str__(self, *args, **kwargs): + return self.file_name + + class DictStringFileSourceFactory: """ Factory that can create `StringAsFileSource` for tests. Based on the provided dict, it will create a proper `StringAsFileSource`. diff --git a/integration-test.sh b/integration-test.sh index 00e4e852a..7a27eafd8 100755 --- a/integration-test.sh +++ b/integration-test.sh @@ -512,11 +512,11 @@ function test_node_management_commands { info "test start [--configuration-name=${cfg}]" esrally start --quiet --configuration-name="${cfg}" --installation-id="${install_id}" --race-id="rally-integration-test" - esrally --target-host="localhost:39200" \ + esrally race-async \ + --target-host="localhost:39200" \ --configuration-name="${cfg}" \ --race-id="rally-integration-test" \ --on-error=abort \ - --pipeline=benchmark-only \ --track=geonames \ --test-mode \ --challenge=append-no-conflicts-index-only diff --git a/setup.cfg b/setup.cfg index 8270bbd0c..d4e57b23a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,6 +2,9 @@ test=pytest [tool:pytest] +# set to true for more verbose output of tests +log_cli=false +log_level=INFO addopts = --verbose --color=yes testpaths = tests junit_family = xunit2 diff --git a/setup.py b/setup.py index 1916dbcf2..8ae887779 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,11 @@ def str_from_file(name): # License: Apache 2.0 # transitive dependency urllib3: MIT "elasticsearch==7.0.5", + # License: Apache 2.0 + # transitive dependencies: + # aiohttp: Apache 2.0 + # async_timeout: Apache 2.0 + "elasticsearch-async==6.2.0", # License: BSD "psutil==5.7.0", # License: MIT @@ -73,7 +78,11 @@ def str_from_file(name): # botocore: Apache 2.0 # jmespath: MIT # s3transfer: Apache 2.0 - "boto3==1.10.32" + "boto3==1.10.32", + # License: Apache 2.0 + "yappi==1.2.3", + # License: BSD + "ijson==2.6.1" ] tests_require = [ diff --git a/tests/__init__.py b/tests/__init__.py index a2833546a..c0e50348c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -14,3 +14,35 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import asyncio + + +def run_async(t): + """ + A wrapper that ensures that a test is run in an asyncio context. + + :param t: The test case to wrap. + """ + def async_wrapper(*args, **kwargs): + asyncio.run(t(*args, **kwargs), debug=True) + return async_wrapper + + +def as_future(result=None, exception=None): + """ + + Helper to create a future that completes immediately either with a result or exceptionally. + + :param result: Regular result. + :param exception: Exceptional result. + :return: The corresponding future. + """ + f = asyncio.Future() + if exception and result: + raise AssertionError("Specify a result or an exception but not both") + if exception: + f.set_exception(exception) + else: + f.set_result(result) + return f diff --git a/tests/driver/async_driver_test.py b/tests/driver/async_driver_test.py new file mode 100644 index 000000000..8a4fb4dfd --- /dev/null +++ b/tests/driver/async_driver_test.py @@ -0,0 +1,213 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import concurrent.futures +import io +import json +import time +from datetime import datetime +from unittest import TestCase, mock + +from esrally import config, metrics +from esrally.driver import async_driver +from esrally.track import track, params +from tests import as_future + + +class TimerTests(TestCase): + class Counter: + def __init__(self): + self.count = 0 + + def __call__(self): + self.count += 1 + + def test_scheduled_tasks(self): + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + + timer = async_driver.Timer(wakeup_interval=0.1) + counter = TimerTests.Counter() + timer.add_task(fn=counter, interval=0.2) + + pool.submit(timer) + + time.sleep(0.45) + timer.stop() + pool.shutdown() + + self.assertEqual(2, counter.count) + + +class StaticClientFactory: + SYNC_PATCHER = None + ASYNC_PATCHER = None + + def __init__(self, *args, **kwargs): + StaticClientFactory.SYNC_PATCHER = mock.patch("elasticsearch.Elasticsearch") + self.es = StaticClientFactory.SYNC_PATCHER.start() + self.es.indices.stats.return_value = {"mocked": True} + self.es.info.return_value = { + "cluster_name": "elasticsearch", + "version": { + "number": "7.3.0", + "build_flavor": "oss", + "build_type": "tar", + "build_hash": "de777fa", + "build_date": "2019-07-24T18:30:11.767338Z", + "build_snapshot": False, + "lucene_version": "8.1.0", + "minimum_wire_compatibility_version": "6.8.0", + "minimum_index_compatibility_version": "6.0.0-beta1" + } + } + + StaticClientFactory.ASYNC_PATCHER = mock.patch("elasticsearch.Elasticsearch") + self.es_async = StaticClientFactory.ASYNC_PATCHER.start() + # we want to simulate that the request took 10 seconds. Internally this is measured using `time#perf_counter` + # and the code relies that measurements are taken consistently with `time#perf_counter` because in some places + # we take a value and will subtract other measurements (e.g. in the main loop in AsyncExecutor we subtract + # `total_start` from `stop`. + # + # On some systems (MacOS), `time#perf_counter` starts at zero when the process is started but on others (Linux), + # `time#perf_counter` starts when the OS is started. Thus we need to ensure that this value here is roughly + # consistent across all systems by using the current value of `time#perf_counter` as basis. + + start = time.perf_counter() + self.es_async.init_request_context.return_value = { + "request_start": start, + "request_end": start + 10 + } + bulk_response = { + "errors": False, + "took": 5 + } + # bulk responses are raw strings + self.es_async.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) + self.es_async.transport.close.return_value = as_future() + + def create(self): + return self.es + + def create_async(self): + return self.es_async + + @classmethod + def close(cls): + StaticClientFactory.SYNC_PATCHER.stop() + StaticClientFactory.ASYNC_PATCHER.stop() + + +class AsyncDriverTestParamSource: + def __init__(self, track=None, params=None, **kwargs): + if params is None: + params = {} + self._indices = track.indices + self._params = params + self._current = 1 + self._total = params.get("size") + self.infinite = self._total is None + + def partition(self, partition_index, total_partitions): + return self + + @property + def percent_completed(self): + if self.infinite: + return None + return self._current / self._total + + def params(self): + if not self.infinite and self._current > self._total: + raise StopIteration() + self._current += 1 + return self._params + + +class AsyncDriverTests(TestCase): + class Holder: + def __init__(self, all_hosts=None, all_client_options=None): + self.all_hosts = all_hosts + self.all_client_options = all_client_options + + def test_run_benchmark(self): + cfg = config.Config() + + cfg.add(config.Scope.application, "system", "env.name", "unittest") + cfg.add(config.Scope.application, "system", "time.start", + datetime(year=2017, month=8, day=20, hour=1, minute=0, second=0)) + cfg.add(config.Scope.application, "system", "race.id", "6ebc6e53-ee20-4b0c-99b4-09697987e9f4") + cfg.add(config.Scope.application, "system", "offline.mode", False) + cfg.add(config.Scope.application, "driver", "on.error", "abort") + cfg.add(config.Scope.application, "driver", "profiling", False) + cfg.add(config.Scope.application, "reporting", "datastore.type", "in-memory") + cfg.add(config.Scope.application, "track", "params", {}) + cfg.add(config.Scope.application, "track", "test.mode.enabled", True) + cfg.add(config.Scope.application, "telemetry", "devices", []) + cfg.add(config.Scope.application, "telemetry", "params", {}) + cfg.add(config.Scope.application, "mechanic", "car.names", ["external"]) + cfg.add(config.Scope.application, "mechanic", "skip.rest.api.check", True) + cfg.add(config.Scope.application, "client", "hosts", + AsyncDriverTests.Holder(all_hosts={"default": ["localhost:9200"]})) + cfg.add(config.Scope.application, "client", "options", + AsyncDriverTests.Holder(all_client_options={"default": {}})) + + params.register_param_source_for_name("bulk-param-source", AsyncDriverTestParamSource) + + task = track.Task(name="bulk-index", + operation=track.Operation( + "bulk-index", + track.OperationType.Bulk.name, + params={ + "body": ["action_metadata_line", "index_line"], + "action-metadata-present": True, + "bulk-size": 1, + # we need this because the parameter source does not know that we only have one + # bulk and hence size() returns incorrect results + "size": 1 + }, + param_source="bulk-param-source"), + warmup_iterations=0, + iterations=1, + clients=1) + + current_challenge = track.Challenge(name="default", default=True, schedule=[task]) + current_track = track.Track(name="unit-test", challenges=[current_challenge]) + + driver = async_driver.AsyncDriver(cfg, current_track, current_challenge, + es_client_factory_class=StaticClientFactory) + + distribution_flavor, distribution_version, revision = driver.setup() + self.assertEqual("oss", distribution_flavor) + self.assertEqual("7.3.0", distribution_version) + self.assertEqual("de777fa", revision) + + metrics_store_representation = driver.run() + + metric_store = metrics.metrics_store(cfg, read_only=True, track=current_track, challenge=current_challenge) + metric_store.bulk_add(metrics_store_representation) + + self.assertIsNotNone(metric_store.get_one(name="latency", task="bulk-index", sample_type=metrics.SampleType.Normal)) + self.assertIsNotNone(metric_store.get_one(name="service_time", task="bulk-index", sample_type=metrics.SampleType.Normal)) + self.assertIsNotNone(metric_store.get_one(name="processing_time", task="bulk-index", sample_type=metrics.SampleType.Normal)) + self.assertIsNotNone(metric_store.get_one(name="throughput", task="bulk-index", sample_type=metrics.SampleType.Normal)) + self.assertIsNotNone(metric_store.get_one(name="node_total_young_gen_gc_time", sample_type=metrics.SampleType.Normal)) + self.assertIsNotNone(metric_store.get_one(name="node_total_old_gen_gc_time", sample_type=metrics.SampleType.Normal)) + # ensure that there are not more documents than we expect + self.assertEqual(6, len(metric_store.docs), msg=json.dumps(metric_store.docs, indent=2)) + + def tearDown(self): + StaticClientFactory.close() diff --git a/tests/driver/driver_test.py b/tests/driver/driver_test.py index 5be335035..2f15a05b0 100644 --- a/tests/driver/driver_test.py +++ b/tests/driver/driver_test.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. +import asyncio import collections +import io import threading import time import unittest.mock as mock @@ -25,6 +27,7 @@ from esrally import metrics, track, exceptions, config from esrally.driver import driver, runner from esrally.track import params +from tests import run_async, as_future class DriverTestParamSource: @@ -218,25 +221,6 @@ def test_client_reaches_join_point_which_completes_parent(self): self.assertEqual(4, target.drive_at.call_count) -class ScheduleTestCase(TestCase): - def assert_schedule(self, expected_schedule, schedule, infinite_schedule=False): - if not infinite_schedule: - self.assertEqual(len(expected_schedule), len(schedule), - msg="Number of elements in the schedules do not match") - idx = 0 - for invocation_time, sample_type, progress_percent, runner, params in schedule: - exp_invocation_time, exp_sample_type, exp_progress_percent, exp_params = expected_schedule[idx] - self.assertAlmostEqual(exp_invocation_time, invocation_time, msg="Invocation time for sample at index %d does not match" % idx) - self.assertEqual(exp_sample_type, sample_type, "Sample type for sample at index %d does not match" % idx) - self.assertEqual(exp_progress_percent, progress_percent, "Current progress for sample at index %d does not match" % idx) - self.assertIsNotNone(runner, "runner must be defined") - self.assertEqual(exp_params, params, "Parameters do not match") - idx += 1 - # for infinite schedules we only check the first few elements - if infinite_schedule and idx == len(expected_schedule): - break - - def op(name, operation_type): return track.Operation(name, operation_type, param_source="driver-test-param-source") @@ -400,8 +384,8 @@ def test_different_sample_types(self): op = track.Operation("index", track.OperationType.Bulk, param_source="driver-test-param-source") samples = [ - driver.Sample(0, 1470838595, 21, op, metrics.SampleType.Warmup, None, -1, -1, 3000, "docs", 1, 1), - driver.Sample(0, 1470838595.5, 21.5, op, metrics.SampleType.Normal, None, -1, -1, 2500, "docs", 1, 1), + driver.Sample(0, 1470838595, 21, op, metrics.SampleType.Warmup, None, -1, -1, -1, 3000, "docs", 1, 1), + driver.Sample(0, 1470838595.5, 21.5, op, metrics.SampleType.Normal, None, -1, -1, -1, 2500, "docs", 1, 1), ] aggregated = self.calculate_global_throughput(samples) @@ -418,15 +402,15 @@ def test_single_metrics_aggregation(self): op = track.Operation("index", track.OperationType.Bulk, param_source="driver-test-param-source") samples = [ - driver.Sample(0, 1470838595, 21, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 1, 1 / 9), - driver.Sample(0, 1470838596, 22, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 2, 2 / 9), - driver.Sample(0, 1470838597, 23, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 3, 3 / 9), - driver.Sample(0, 1470838598, 24, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 4, 4 / 9), - driver.Sample(0, 1470838599, 25, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 5, 5 / 9), - driver.Sample(0, 1470838600, 26, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 6, 6 / 9), - driver.Sample(1, 1470838598.5, 24.5, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 4.5, 7 / 9), - driver.Sample(1, 1470838599.5, 25.5, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 5.5, 8 / 9), - driver.Sample(1, 1470838600.5, 26.5, op, metrics.SampleType.Normal, None, -1, -1, 5000, "docs", 6.5, 9 / 9) + driver.Sample(0, 38595, 21, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 1, 1 / 9), + driver.Sample(0, 38596, 22, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 2, 2 / 9), + driver.Sample(0, 38597, 23, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 3, 3 / 9), + driver.Sample(0, 38598, 24, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 4, 4 / 9), + driver.Sample(0, 38599, 25, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 5, 5 / 9), + driver.Sample(0, 38600, 26, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 6, 6 / 9), + driver.Sample(1, 38598.5, 24.5, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 4.5, 7 / 9), + driver.Sample(1, 38599.5, 25.5, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 5.5, 8 / 9), + driver.Sample(1, 38600.5, 26.5, op, metrics.SampleType.Normal, None, -1, -1, -1, 5000, "docs", 6.5, 9 / 9) ] aggregated = self.calculate_global_throughput(samples) @@ -436,19 +420,19 @@ def test_single_metrics_aggregation(self): throughput = aggregated[op] self.assertEqual(6, len(throughput)) - self.assertEqual((1470838595, 21, metrics.SampleType.Normal, 5000, "docs/s"), throughput[0]) - self.assertEqual((1470838596, 22, metrics.SampleType.Normal, 5000, "docs/s"), throughput[1]) - self.assertEqual((1470838597, 23, metrics.SampleType.Normal, 5000, "docs/s"), throughput[2]) - self.assertEqual((1470838598, 24, metrics.SampleType.Normal, 5000, "docs/s"), throughput[3]) - self.assertEqual((1470838599, 25, metrics.SampleType.Normal, 6000, "docs/s"), throughput[4]) - self.assertEqual((1470838600, 26, metrics.SampleType.Normal, 6666.666666666667, "docs/s"), throughput[5]) + self.assertEqual((38595, 21, metrics.SampleType.Normal, 5000, "docs/s"), throughput[0]) + self.assertEqual((38596, 22, metrics.SampleType.Normal, 5000, "docs/s"), throughput[1]) + self.assertEqual((38597, 23, metrics.SampleType.Normal, 5000, "docs/s"), throughput[2]) + self.assertEqual((38598, 24, metrics.SampleType.Normal, 5000, "docs/s"), throughput[3]) + self.assertEqual((38599, 25, metrics.SampleType.Normal, 6000, "docs/s"), throughput[4]) + self.assertEqual((38600, 26, metrics.SampleType.Normal, 6666.666666666667, "docs/s"), throughput[5]) # self.assertEqual((1470838600.5, 26.5, metrics.SampleType.Normal, 10000), throughput[6]) def calculate_global_throughput(self, samples): return driver.ThroughputCalculator().calculate(samples) -class SchedulerTests(ScheduleTestCase): +class SchedulerTests(TestCase): class RunnerWithProgress: def __init__(self, complete_after=3): self.completed = False @@ -456,7 +440,7 @@ def __init__(self, complete_after=3): self.calls = 0 self.complete_after = complete_after - def __call__(self, *args, **kwargs): + async def __call__(self, *args, **kwargs): self.calls += 1 if not self.completed: self.percent_completed = self.calls / self.complete_after @@ -464,21 +448,37 @@ def __call__(self, *args, **kwargs): else: self.percent_completed = 1.0 + async def assert_schedule(self, expected_schedule, schedule, infinite_schedule=False): + idx = 0 + async for invocation_time, sample_type, progress_percent, runner, params in schedule: + exp_invocation_time, exp_sample_type, exp_progress_percent, exp_params = expected_schedule[idx] + self.assertAlmostEqual(exp_invocation_time, invocation_time, msg="Invocation time for sample at index %d does not match" % idx) + self.assertEqual(exp_sample_type, sample_type, "Sample type for sample at index %d does not match" % idx) + self.assertEqual(exp_progress_percent, progress_percent, "Current progress for sample at index %d does not match" % idx) + self.assertIsNotNone(runner, "runner must be defined") + self.assertEqual(exp_params, params, "Parameters do not match") + idx += 1 + # for infinite schedules we only check the first few elements + if infinite_schedule and idx == len(expected_schedule): + break + if not infinite_schedule: + self.assertEqual(len(expected_schedule), idx, msg="Number of elements in the schedules do not match") + def setUp(self): self.test_track = track.Track(name="unittest") self.runner_with_progress = SchedulerTests.RunnerWithProgress() params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) runner.register_default_runners() - runner.register_runner("driver-test-runner-with-completion", self.runner_with_progress) + runner.register_runner("driver-test-runner-with-completion", self.runner_with_progress, async_runner=True) def tearDown(self): runner.remove_runner("driver-test-runner-with-completion") - def test_search_task_one_client(self): + @run_async + async def test_search_task_one_client(self): task = track.Task("search", track.Operation("search", track.OperationType.Search.name, param_source="driver-test-param-source"), warmup_iterations=3, iterations=5, clients=1, params={"target-throughput": 10, "clients": 1}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) expected_schedule = [ (0, metrics.SampleType.Warmup, 1 / 8, {}), @@ -490,13 +490,13 @@ def test_search_task_one_client(self): (0.6, metrics.SampleType.Normal, 7 / 8, {}), (0.7, metrics.SampleType.Normal, 8 / 8, {}), ] - self.assert_schedule(expected_schedule, list(schedule)) + await self.assert_schedule(expected_schedule, schedule()) - def test_search_task_two_clients(self): + @run_async + async def test_search_task_two_clients(self): task = track.Task("search", track.Operation("search", track.OperationType.Search.name, param_source="driver-test-param-source"), warmup_iterations=1, iterations=5, clients=2, params={"target-throughput": 10, "clients": 2}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) expected_schedule = [ (0, metrics.SampleType.Warmup, 1 / 6, {}), @@ -506,61 +506,61 @@ def test_search_task_two_clients(self): (0.8, metrics.SampleType.Normal, 5 / 6, {}), (1.0, metrics.SampleType.Normal, 6 / 6, {}), ] - self.assert_schedule(expected_schedule, list(schedule)) + await self.assert_schedule(expected_schedule, schedule()) - def test_schedule_param_source_determines_iterations_no_warmup(self): + @run_async + async def test_schedule_param_source_determines_iterations_no_warmup(self): # we neither define any time-period nor any iteration count on the task. task = track.Task("bulk-index", track.Operation("bulk-index", track.OperationType.Bulk.name, params={"body": ["a"], "size": 3}, param_source="driver-test-param-source"), clients=1, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, 1 / 3, {"body": ["a"], "size": 3}), (1.0, metrics.SampleType.Normal, 2 / 3, {"body": ["a"], "size": 3}), (2.0, metrics.SampleType.Normal, 3 / 3, {"body": ["a"], "size": 3}), - ], list(schedule)) + ], schedule()) - def test_schedule_param_source_determines_iterations_including_warmup(self): + @run_async + async def test_schedule_param_source_determines_iterations_including_warmup(self): task = track.Task("bulk-index", track.Operation("bulk-index", track.OperationType.Bulk.name, params={"body": ["a"], "size": 5}, param_source="driver-test-param-source"), warmup_iterations=2, clients=1, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Warmup, 1 / 5, {"body": ["a"], "size": 5}), (1.0, metrics.SampleType.Warmup, 2 / 5, {"body": ["a"], "size": 5}), (2.0, metrics.SampleType.Normal, 3 / 5, {"body": ["a"], "size": 5}), (3.0, metrics.SampleType.Normal, 4 / 5, {"body": ["a"], "size": 5}), (4.0, metrics.SampleType.Normal, 5 / 5, {"body": ["a"], "size": 5}), - ], list(schedule)) + ], schedule()) - def test_schedule_defaults_to_iteration_based(self): + @run_async + async def test_schedule_defaults_to_iteration_based(self): # no time-period and no iterations specified on the task. Also, the parameter source does not define a size. task = track.Task("bulk-index", track.Operation("bulk-index", track.OperationType.Bulk.name, params={"body": ["a"]}, param_source="driver-test-param-source"), clients=1, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, 1 / 1, {"body": ["a"]}), - ], list(schedule)) + ], schedule()) - def test_schedule_for_warmup_time_based(self): + @run_async + async def test_schedule_for_warmup_time_based(self): task = track.Task("time-based", track.Operation("time-based", track.OperationType.Bulk.name, params={"body": ["a"], "size": 11}, param_source="driver-test-param-source"), warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, 1 / 11, {"body": ["a"], "size": 11}), (1.0, metrics.SampleType.Normal, 2 / 11, {"body": ["a"], "size": 11}), (2.0, metrics.SampleType.Normal, 3 / 11, {"body": ["a"], "size": 11}), @@ -572,71 +572,70 @@ def test_schedule_for_warmup_time_based(self): (8.0, metrics.SampleType.Normal, 9 / 11, {"body": ["a"], "size": 11}), (9.0, metrics.SampleType.Normal, 10 / 11, {"body": ["a"], "size": 11}), (10.0, metrics.SampleType.Normal, 11 / 11, {"body": ["a"], "size": 11}), - ], list(schedule)) + ], schedule()) - def test_infinite_schedule_without_progress_indication(self): + @run_async + async def test_infinite_schedule_without_progress_indication(self): task = track.Task("time-based", track.Operation("time-based", track.OperationType.Bulk.name, params={"body": ["a"]}, param_source="driver-test-param-source"), warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - ], schedule, infinite_schedule=True) + ], schedule(), infinite_schedule=True) - def test_finite_schedule_with_progress_indication(self): + @run_async + async def test_finite_schedule_with_progress_indication(self): task = track.Task("time-based", track.Operation("time-based", track.OperationType.Bulk.name, params={"body": ["a"], "size": 5}, param_source="driver-test-param-source"), warmup_time_period=0, clients=4, params={"target-throughput": 4, "clients": 4}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, 1 / 5, {"body": ["a"], "size": 5}), (1.0, metrics.SampleType.Normal, 2 / 5, {"body": ["a"], "size": 5}), (2.0, metrics.SampleType.Normal, 3 / 5, {"body": ["a"], "size": 5}), (3.0, metrics.SampleType.Normal, 4 / 5, {"body": ["a"], "size": 5}), (4.0, metrics.SampleType.Normal, 5 / 5, {"body": ["a"], "size": 5}), - ], list(schedule), infinite_schedule=False) + ], schedule(), infinite_schedule=False) - def test_schedule_with_progress_determined_by_runner(self): + @run_async + async def test_schedule_with_progress_determined_by_runner(self): task = track.Task("time-based", track.Operation("time-based", "driver-test-runner-with-completion", params={"body": ["a"]}, param_source="driver-test-param-source"), clients=1, params={"target-throughput": 1, "clients": 1}) - schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = schedule_handle() + schedule = driver.schedule_for(self.test_track, task, 0) - self.assert_schedule([ + await self.assert_schedule([ (0.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (1.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (2.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (3.0, metrics.SampleType.Normal, None, {"body": ["a"]}), (4.0, metrics.SampleType.Normal, None, {"body": ["a"]}), - ], schedule, infinite_schedule=True) + ], schedule(), infinite_schedule=True) - def test_schedule_for_time_based(self): + @run_async + async def test_schedule_for_time_based(self): task = track.Task("time-based", track.Operation("time-based", track.OperationType.Bulk.name, params={"body": ["a"], "size": 11}, param_source="driver-test-param-source"), warmup_time_period=0.1, time_period=0.1, clients=1) schedule_handle = driver.schedule_for(self.test_track, task, 0) - schedule = list(schedule_handle()) - - self.assertTrue(len(schedule) > 0) + schedule = schedule_handle() last_progress = -1 - for invocation_time, sample_type, progress_percent, runner, params in schedule: + async for invocation_time, sample_type, progress_percent, runner, params in schedule: # we're not throughput throttled self.assertEqual(0, invocation_time) if progress_percent <= 0.5: @@ -651,18 +650,18 @@ def test_schedule_for_time_based(self): self.assertEqual({"body": ["a"], "size": 11}, params) -class ExecutorTests(TestCase): +class AsyncExecutorTests(TestCase): class NoopContextManager: def __init__(self, mock): self.mock = mock - def __enter__(self): + async def __aenter__(self): return self - def __call__(self, *args): - return self.mock(*args) + async def __call__(self, *args): + return await self.mock(*args) - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb): return False def __str__(self): @@ -681,7 +680,7 @@ def completed(self): def percent_completed(self): return (self.iterations - self.iterations_left) / self.iterations - def __call__(self, es, params): + async def __call__(self, es, params): self.iterations_left -= 1 def __init__(self, methodName): @@ -689,18 +688,21 @@ def __init__(self, methodName): self.runner_with_progress = None def context_managed(self, mock): - return ExecutorTests.NoopContextManager(mock) + return AsyncExecutorTests.NoopContextManager(mock) def setUp(self): runner.register_default_runners() - self.runner_with_progress = ExecutorTests.RunnerWithProgress() - runner.register_runner("unit-test-recovery", self.runner_with_progress) + self.runner_with_progress = AsyncExecutorTests.RunnerWithProgress() + runner.register_runner("unit-test-recovery", self.runner_with_progress, async_runner=True) @mock.patch("elasticsearch.Elasticsearch") - def test_execute_schedule_in_throughput_mode(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_execute_schedule_in_throughput_mode(self, es): + es.init_request_context.return_value = { + "request_start": 0, + "request_end": 10 } + es.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}')) params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", @@ -718,12 +720,20 @@ def test_execute_schedule_in_throughput_mode(self, es): warmup_time_period=0, clients=4) schedule = driver.schedule_for(test_track, task, 0) - sampler = driver.Sampler(client_id=2, task=task, start_timestamp=time.perf_counter()) + sampler = driver.Sampler(start_timestamp=time.perf_counter()) cancel = threading.Event() complete = threading.Event() - execute_schedule = driver.Executor(task, schedule, es, sampler, cancel, complete) - execute_schedule() + execute_schedule = driver.AsyncExecutor(client_id=2, + task=task, + schedule=schedule, + es={ + "default": es + }, + sampler=sampler, + cancel=cancel, + complete=complete) + await execute_schedule() samples = sampler.samples @@ -747,10 +757,13 @@ def test_execute_schedule_in_throughput_mode(self, es): self.assertEqual(1, sample.request_meta_data["bulk-size"]) @mock.patch("elasticsearch.Elasticsearch") - def test_execute_schedule_with_progress_determined_by_runner(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_execute_schedule_with_progress_determined_by_runner(self, es): + es.init_request_context.return_value = { + "request_start": 0, + "request_end": 10 } + es.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}')) params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", @@ -764,12 +777,20 @@ def test_execute_schedule_with_progress_determined_by_runner(self, es): }, param_source="driver-test-param-source"), warmup_time_period=0, clients=4) schedule = driver.schedule_for(test_track, task, 0) - sampler = driver.Sampler(client_id=2, task=task, start_timestamp=time.perf_counter()) + sampler = driver.Sampler(start_timestamp=time.perf_counter()) cancel = threading.Event() complete = threading.Event() - execute_schedule = driver.Executor(task, schedule, es, sampler, cancel, complete) - execute_schedule() + execute_schedule = driver.AsyncExecutor(client_id=2, + task=task, + schedule=schedule, + es={ + "default": es + }, + sampler=sampler, + cancel=cancel, + complete=complete) + await execute_schedule() samples = sampler.samples @@ -794,10 +815,18 @@ def test_execute_schedule_with_progress_determined_by_runner(self, es): self.assertEqual("ops", sample.total_ops_unit) @mock.patch("elasticsearch.Elasticsearch") - def test_execute_schedule_throughput_throttled(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_execute_schedule_throughput_throttled(self, es): + def bulk(*args, **kwargs): + return as_future(io.StringIO('{"errors": false, "took": 8}')) + + es.init_request_context.return_value = { + "request_start": 0, + "request_end": 10 } + # as this method is called several times we need to return a fresh StringIO instance every time as the previous + # one has been "consumed". + es.bulk.side_effect = bulk params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", @@ -815,14 +844,22 @@ def test_execute_schedule_throughput_throttled(self, es): warmup_time_period=0.5, time_period=0.5, clients=4, params={"target-throughput": target_throughput, "clients": 4}, completes_parent=True) - sampler = driver.Sampler(client_id=0, task=task, start_timestamp=0) + sampler = driver.Sampler(start_timestamp=0) cancel = threading.Event() complete = threading.Event() schedule = driver.schedule_for(test_track, task, 0) - execute_schedule = driver.Executor(task, schedule, es, sampler, cancel, complete) - execute_schedule() + execute_schedule = driver.AsyncExecutor(client_id=0, + task=task, + schedule=schedule, + es={ + "default": es + }, + sampler=sampler, + cancel=cancel, + complete=complete) + await execute_schedule() samples = sampler.samples @@ -834,10 +871,13 @@ def test_execute_schedule_throughput_throttled(self, es): self.assertTrue(complete.is_set(), "Executor should auto-complete a task that terminates its parent") @mock.patch("elasticsearch.Elasticsearch") - def test_cancel_execute_schedule(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_cancel_execute_schedule(self, es): + es.init_request_context.return_value = { + "request_start": 0, + "request_end": 10 } + es.bulk.return_value = as_future(io.StringIO('{"errors": false, "took": 8}')) params.register_param_source_for_name("driver-test-param-source", DriverTestParamSource) test_track = track.Track(name="unittest", description="unittest track", @@ -855,14 +895,22 @@ def test_cancel_execute_schedule(self, es): warmup_time_period=0.5, time_period=0.5, clients=4, params={"target-throughput": target_throughput, "clients": 4}) schedule = driver.schedule_for(test_track, task, 0) - sampler = driver.Sampler(client_id=0, task=task, start_timestamp=0) + sampler = driver.Sampler(start_timestamp=0) cancel = threading.Event() complete = threading.Event() - execute_schedule = driver.Executor(task, schedule, es, sampler, cancel, complete) + execute_schedule = driver.AsyncExecutor(client_id=0, + task=task, + schedule=schedule, + es={ + "default": es + }, + sampler=sampler, + cancel=cancel, + complete=complete) cancel.set() - execute_schedule() + await execute_schedule() samples = sampler.samples @@ -870,86 +918,102 @@ def test_cancel_execute_schedule(self, es): self.assertEqual(0, sample_size) @mock.patch("elasticsearch.Elasticsearch") - def test_execute_schedule_aborts_on_error(self, es): + @run_async + async def test_execute_schedule_aborts_on_error(self, es): class ExpectedUnitTestException(Exception): pass def run(*args, **kwargs): raise ExpectedUnitTestException() - def schedule_handle(): - return [(0, metrics.SampleType.Warmup, 0, self.context_managed(run), None)] + async def schedule_handle(): + invocations = [(0, metrics.SampleType.Warmup, 0, self.context_managed(run), None)] + for invocation in invocations: + yield invocation task = track.Task("no-op", track.Operation("no-op", track.OperationType.Bulk.name, params={}, param_source="driver-test-param-source"), warmup_time_period=0.5, time_period=0.5, clients=4, params={"clients": 4}) - sampler = driver.Sampler(client_id=0, task=None, start_timestamp=0) + sampler = driver.Sampler(start_timestamp=0) cancel = threading.Event() complete = threading.Event() - execute_schedule = driver.Executor(task, schedule_handle, es, sampler, cancel, complete) + execute_schedule = driver.AsyncExecutor(client_id=2, + task=task, + schedule=schedule_handle, + es={ + "default": es + }, + sampler=sampler, + cancel=cancel, + complete=complete) with self.assertRaises(ExpectedUnitTestException): - execute_schedule() + await execute_schedule() self.assertEqual(0, es.call_count) - def test_execute_single_no_return_value(self): + @run_async + async def test_execute_single_no_return_value(self): es = None params = None runner = mock.Mock() + runner.return_value = as_future() - total_ops, total_ops_unit, request_meta_data = driver.execute_single(self.context_managed(runner), es, params) + ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params) - self.assertEqual(1, total_ops) - self.assertEqual("ops", total_ops_unit) + self.assertEqual(1, ops) + self.assertEqual("ops", unit) self.assertEqual({"success": True}, request_meta_data) - def test_execute_single_tuple(self): + @run_async + async def test_execute_single_tuple(self): es = None params = None runner = mock.Mock() - runner.return_value = (500, "MB") + runner.return_value = as_future(result=(500, "MB")) - total_ops, total_ops_unit, request_meta_data = driver.execute_single(self.context_managed(runner), es, params) + ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params) - self.assertEqual(500, total_ops) - self.assertEqual("MB", total_ops_unit) + self.assertEqual(500, ops) + self.assertEqual("MB", unit) self.assertEqual({"success": True}, request_meta_data) - def test_execute_single_dict(self): + @run_async + async def test_execute_single_dict(self): es = None params = None runner = mock.Mock() - runner.return_value = { + runner.return_value = as_future({ "weight": 50, "unit": "docs", "some-custom-meta-data": "valid", "http-status": 200 - } + }) - total_ops, total_ops_unit, request_meta_data = driver.execute_single(self.context_managed(runner), es, params) + ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params) - self.assertEqual(50, total_ops) - self.assertEqual("docs", total_ops_unit) + self.assertEqual(50, ops) + self.assertEqual("docs", unit) self.assertEqual({ "some-custom-meta-data": "valid", "http-status": 200, "success": True }, request_meta_data) - def test_execute_single_with_connection_error(self): + @run_async + async def test_execute_single_with_connection_error(self): import elasticsearch es = None params = None # ES client uses pseudo-status "N/A" in this case... - runner = mock.Mock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host", None)) + runner = mock.Mock(side_effect=as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host", None))) - total_ops, total_ops_unit, request_meta_data = driver.execute_single(self.context_managed(runner), es, params) + ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params) - self.assertEqual(0, total_ops) - self.assertEqual("ops", total_ops_unit) + self.assertEqual(0, ops) + self.assertEqual("ops", unit) self.assertEqual({ # Look ma: No http-status! "error-description": "no route to host", @@ -957,16 +1021,18 @@ def test_execute_single_with_connection_error(self): "success": False }, request_meta_data) - def test_execute_single_with_http_400(self): + @run_async + async def test_execute_single_with_http_400(self): import elasticsearch es = None params = None - runner = mock.Mock(side_effect=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found")) + runner = mock.Mock(side_effect= + as_future(exception=elasticsearch.NotFoundError(404, "not found", "the requested document could not be found"))) - total_ops, total_ops_unit, request_meta_data = driver.execute_single(self.context_managed(runner), es, params) + ops, unit, request_meta_data = await driver.execute_single(self.context_managed(runner), es, params) - self.assertEqual(0, total_ops) - self.assertEqual("ops", total_ops_unit) + self.assertEqual(0, ops) + self.assertEqual("ops", unit) self.assertEqual({ "http-status": 404, "error-type": "transport", @@ -974,9 +1040,10 @@ def test_execute_single_with_http_400(self): "success": False }, request_meta_data) - def test_execute_single_with_key_error(self): + @run_async + async def test_execute_single_with_key_error(self): class FailingRunner: - def __call__(self, *args): + async def __call__(self, *args): raise KeyError("bulk-size missing") def __str__(self): @@ -990,24 +1057,25 @@ def __str__(self): runner = FailingRunner() with self.assertRaises(exceptions.SystemSetupError) as ctx: - driver.execute_single(self.context_managed(runner), es, params) + await driver.execute_single(self.context_managed(runner), es, params) self.assertEqual( "Cannot execute [failing_mock_runner]. Provided parameters are: ['bulk', 'mode']. Error: ['bulk-size missing'].", ctx.exception.args[0]) -class ProfilerTests(TestCase): - def test_profiler_is_a_transparent_wrapper(self): +class AsyncProfilerTests(TestCase): + @run_async + async def test_profiler_is_a_transparent_wrapper(self): import time - def f(x): - time.sleep(x) + async def f(x): + await asyncio.sleep(x) return x * 2 - profiler = driver.Profiler(f, 0, "sleep-operation") + profiler = driver.AsyncProfiler(f) start = time.perf_counter() # this should take roughly 1 second and should return something - return_value = profiler(1) + return_value = await profiler(1) end = time.perf_counter() self.assertEqual(2, return_value) duration = end - start diff --git a/tests/driver/runner_test.py b/tests/driver/runner_test.py index f6b3574c6..6be984df8 100644 --- a/tests/driver/runner_test.py +++ b/tests/driver/runner_test.py @@ -16,23 +16,24 @@ # under the License. import io +import json import random import unittest.mock as mock from unittest import TestCase import elasticsearch -import pytest from esrally import exceptions from esrally.driver import runner +from tests import run_async, as_future class BaseUnitTestContextManagerRunner: - def __enter__(self): + async def __aenter__(self): self.fp = io.StringIO("many\nlines\nin\na\nfile") return self - def __exit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type, exc_val, exc_tb): self.fp.close() return False @@ -41,48 +42,53 @@ class RegisterRunnerTests(TestCase): def tearDown(self): runner.remove_runner("unit_test") - def test_runner_function_should_be_wrapped(self): - def runner_function(*args): + @run_async + async def test_runner_function_should_be_wrapped(self): + async def runner_function(*args): return args - runner.register_runner(operation_type="unit_test", runner=runner_function) + runner.register_runner(operation_type="unit_test", runner=runner_function, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined runner for [runner_function]", repr(returned_runner)) - self.assertEqual(("default_client", "param"), returned_runner({"default": "default_client", "other": "other_client"}, "param")) + self.assertEqual(("default_client", "param"), + await returned_runner({"default": "default_client", "other": "other_client"}, "param")) - def test_single_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): + @run_async + async def test_single_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): class UnitTestSingleClusterContextManagerRunner(BaseUnitTestContextManagerRunner): - def __call__(self, *args): + async def __call__(self, *args): return args def __str__(self): return "UnitTestSingleClusterContextManagerRunner" test_runner = UnitTestSingleClusterContextManagerRunner() - runner.register_runner(operation_type="unit_test", runner=test_runner) + runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined context-manager enabled runner for [UnitTestSingleClusterContextManagerRunner]", repr(returned_runner)) # test that context_manager functionality gets preserved after wrapping - with returned_runner: - self.assertEqual(("default_client", "param"), returned_runner({"default": "default_client", "other": "other_client"}, "param")) + async with returned_runner: + self.assertEqual(("default_client", "param"), + await returned_runner({"default": "default_client", "other": "other_client"}, "param")) # check that the context manager interface of our inner runner has been respected. self.assertTrue(test_runner.fp.closed) - def test_multi_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): + @run_async + async def test_multi_cluster_runner_class_with_context_manager_should_be_wrapped_with_context_manager_enabled(self): class UnitTestMultiClusterContextManagerRunner(BaseUnitTestContextManagerRunner): multi_cluster = True - def __call__(self, *args): + async def __call__(self, *args): return args def __str__(self): return "UnitTestMultiClusterContextManagerRunner" test_runner = UnitTestMultiClusterContextManagerRunner() - runner.register_runner(operation_type="unit_test", runner=test_runner) + runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined context-manager enabled runner for [UnitTestMultiClusterContextManagerRunner]", @@ -90,51 +96,132 @@ def __str__(self): # test that context_manager functionality gets preserved after wrapping all_clients = {"default": "default_client", "other": "other_client"} - with returned_runner: - self.assertEqual((all_clients, "param1", "param2"), returned_runner(all_clients, "param1", "param2")) + async with returned_runner: + self.assertEqual((all_clients, "param1", "param2"), await returned_runner(all_clients, "param1", "param2")) # check that the context manager interface of our inner runner has been respected. self.assertTrue(test_runner.fp.closed) - def test_single_cluster_runner_class_should_be_wrapped(self): + @run_async + async def test_single_cluster_runner_class_should_be_wrapped(self): class UnitTestSingleClusterRunner: - def __call__(self, *args): + async def __call__(self, *args): return args def __str__(self): return "UnitTestSingleClusterRunner" test_runner = UnitTestSingleClusterRunner() - runner.register_runner(operation_type="unit_test", runner=test_runner) + runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined runner for [UnitTestSingleClusterRunner]", repr(returned_runner)) - self.assertEqual(("default_client", "param"), returned_runner({"default": "default_client", "other": "other_client"}, "param")) + self.assertEqual(("default_client", "param"), + await returned_runner({"default": "default_client", "other": "other_client"}, "param")) - def test_multi_cluster_runner_class_should_be_wrapped(self): + @run_async + async def test_multi_cluster_runner_class_should_be_wrapped(self): class UnitTestMultiClusterRunner: multi_cluster = True - def __call__(self, *args): + async def __call__(self, *args): return args def __str__(self): return "UnitTestMultiClusterRunner" test_runner = UnitTestMultiClusterRunner() - runner.register_runner(operation_type="unit_test", runner=test_runner) + runner.register_runner(operation_type="unit_test", runner=test_runner, async_runner=True) returned_runner = runner.runner_for("unit_test") self.assertIsInstance(returned_runner, runner.NoCompletion) self.assertEqual("user-defined runner for [UnitTestMultiClusterRunner]", repr(returned_runner)) all_clients = {"default": "default_client", "other": "other_client"} - self.assertEqual((all_clients, "some_param"), returned_runner(all_clients, "some_param")) + self.assertEqual((all_clients, "some_param"), await returned_runner(all_clients, "some_param")) + + +class SelectiveJsonParserTests(TestCase): + def doc_as_text(self, doc): + return io.StringIO(json.dumps(doc)) + + def test_parse_all_expected(self): + doc = self.doc_as_text({ + "title": "Hello", + "meta": { + "length": 100, + "date": { + "year": 2000 + } + } + }) + + parsed = runner.parse(doc, [ + # simple property + "title", + # a nested property + "meta.date.year", + # ignores unknown properties + "meta.date.month" + ]) + + self.assertEqual("Hello", parsed.get("title")) + self.assertEqual(2000, parsed.get("meta.date.year")) + self.assertNotIn("meta.date.month", parsed) + + def test_list_length(self): + doc = self.doc_as_text({ + "title": "Hello", + "meta": { + "length": 100, + "date": { + "year": 2000 + } + }, + "authors": ["George", "Harry"], + "readers": [ + { + "name": "Tom", + "age": 14 + }, + { + "name": "Bob", + "age": 17 + }, + { + "name": "Alice", + "age": 22 + } + ], + "supporters": [] + }) + + parsed = runner.parse(doc, [ + # simple property + "title", + # a nested property + "meta.date.year", + # ignores unknown properties + "meta.date.month" + ], ["authors", "readers", "supporters"]) + + self.assertEqual("Hello", parsed.get("title")) + self.assertEqual(2000, parsed.get("meta.date.year")) + self.assertNotIn("meta.date.month", parsed) + + # lists + self.assertFalse(parsed.get("authors")) + self.assertFalse(parsed.get("readers")) + self.assertTrue(parsed.get("supporters")) class BulkIndexRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_missing_params(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_bulk_index_missing_params(self, es): + bulk_response = { + "errors": False, + "took": 8 } + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) + bulk = runner.BulkIndex() bulk_params = { @@ -147,15 +234,19 @@ def test_bulk_index_missing_params(self, es): } with self.assertRaises(exceptions.DataError) as ctx: - bulk(es, bulk_params) + await bulk(es, bulk_params) self.assertEqual("Parameter source for operation 'bulk-index' did not provide the mandatory parameter 'action-metadata-present'. " "Please add it to your parameter source.", ctx.exception.args[0]) @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_success_with_metadata(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_bulk_index_success_with_metadata(self, es): + bulk_response = { + "errors": False, + "took": 8 } + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) + bulk = runner.BulkIndex() bulk_params = { @@ -169,9 +260,9 @@ def test_bulk_index_success_with_metadata(self, es): "bulk-size": 3 } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) - self.assertIsNone(result["took"]) + self.assertEqual(8, result["took"]) self.assertIsNone(result["index"]) self.assertEqual(3, result["weight"]) self.assertEqual(3, result["bulk-size"]) @@ -183,10 +274,13 @@ def test_bulk_index_success_with_metadata(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_success_without_metadata_with_doc_type(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_bulk_index_success_without_metadata_with_doc_type(self, es): + bulk_response = { + "errors": False, + "took": 8 } + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) bulk = runner.BulkIndex() bulk_params = { @@ -199,9 +293,9 @@ def test_bulk_index_success_without_metadata_with_doc_type(self, es): "type": "_doc" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) - self.assertIsNone(result["took"]) + self.assertEqual(8, result["took"]) self.assertEqual("test-index", result["index"]) self.assertEqual(3, result["weight"]) self.assertEqual(3, result["bulk-size"]) @@ -213,10 +307,13 @@ def test_bulk_index_success_without_metadata_with_doc_type(self, es): es.bulk.assert_called_with(body=bulk_params["body"], index="test-index", doc_type="_doc", params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_success_without_metadata_and_without_doc_type(self, es): - es.bulk.return_value = { - "errors": False + @run_async + async def test_bulk_index_success_without_metadata_and_without_doc_type(self, es): + bulk_response = { + "errors": False, + "took": 8 } + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) bulk = runner.BulkIndex() bulk_params = { @@ -228,9 +325,9 @@ def test_bulk_index_success_without_metadata_and_without_doc_type(self, es): "index": "test-index" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) - self.assertIsNone(result["took"]) + self.assertEqual(8, result["took"]) self.assertEqual("test-index", result["index"]) self.assertEqual(3, result["weight"]) self.assertEqual(3, result["bulk-size"]) @@ -242,8 +339,9 @@ def test_bulk_index_success_without_metadata_and_without_doc_type(self, es): es.bulk.assert_called_with(body=bulk_params["body"], index="test-index", doc_type=None, params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_error(self, es): - es.bulk.return_value = { + @run_async + async def test_bulk_index_error(self, es): + bulk_response = { "took": 5, "errors": True, "items": [ @@ -279,6 +377,9 @@ def test_bulk_index_error(self, es): }, ] } + + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) + bulk = runner.BulkIndex() bulk_params = { @@ -293,7 +394,7 @@ def test_bulk_index_error(self, es): "index": "test" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) self.assertEqual("test", result["index"]) self.assertEqual(5, result["took"]) @@ -307,8 +408,9 @@ def test_bulk_index_error(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_bulk_index_error_no_shards(self, es): - es.bulk.return_value = { + @run_async + async def test_bulk_index_error_no_shards(self, es): + bulk_response = { "took": 20, "errors": True, "items": [ @@ -341,6 +443,9 @@ def test_bulk_index_error_no_shards(self, es): } ] } + + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) + bulk = runner.BulkIndex() bulk_params = { @@ -356,7 +461,7 @@ def test_bulk_index_error_no_shards(self, es): "index": "test" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) self.assertEqual("test", result["index"]) self.assertEqual(20, result["took"]) @@ -370,8 +475,9 @@ def test_bulk_index_error_no_shards(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_mixed_bulk_with_simple_stats(self, es): - es.bulk.return_value = { + @run_async + async def test_mixed_bulk_with_simple_stats(self, es): + bulk_response = { "took": 30, "ingest_took": 20, "errors": True, @@ -444,6 +550,7 @@ def test_mixed_bulk_with_simple_stats(self, es): } ] } + es.bulk.return_value = as_future(io.StringIO(json.dumps(bulk_response))) bulk = runner.BulkIndex() bulk_params = { @@ -461,11 +568,11 @@ def test_mixed_bulk_with_simple_stats(self, es): "index": "test" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) self.assertEqual("test", result["index"]) self.assertEqual(30, result["took"]) - self.assertEqual(20, result["ingest_took"]) + self.assertNotIn("ingest_took", result, "ingest_took is not extracted with simple stats") self.assertEqual(4, result["weight"]) self.assertEqual(4, result["bulk-size"]) self.assertEqual("docs", result["unit"]) @@ -475,13 +582,10 @@ def test_mixed_bulk_with_simple_stats(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) - es.bulk.return_value.pop("ingest_took") - result = bulk(es, bulk_params) - self.assertNotIn("ingest_took", result) - @mock.patch("elasticsearch.Elasticsearch") - def test_mixed_bulk_with_detailed_stats_body_as_string(self, es): - es.bulk.return_value = { + @run_async + async def test_mixed_bulk_with_detailed_stats_body_as_string(self, es): + es.bulk.return_value = as_future({ "took": 30, "ingest_took": 20, "errors": True, @@ -587,7 +691,7 @@ def test_mixed_bulk_with_detailed_stats_body_as_string(self, es): } } ] - } + }) bulk = runner.BulkIndex() bulk_params = { @@ -609,7 +713,7 @@ def test_mixed_bulk_with_detailed_stats_body_as_string(self, es): "index": "test" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) self.assertEqual("test", result["index"]) self.assertEqual(30, result["took"]) @@ -666,13 +770,14 @@ def test_mixed_bulk_with_detailed_stats_body_as_string(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) - es.bulk.return_value.pop("ingest_took") - result = bulk(es, bulk_params) + es.bulk.return_value.result().pop("ingest_took") + result = await bulk(es, bulk_params) self.assertNotIn("ingest_took", result) @mock.patch("elasticsearch.Elasticsearch") - def test_simple_bulk_with_detailed_stats_body_as_list(self, es): - es.bulk.return_value = { + @run_async + async def test_simple_bulk_with_detailed_stats_body_as_list(self, es): + es.bulk.return_value = as_future({ "took": 30, "ingest_took": 20, "errors": False, @@ -695,7 +800,7 @@ def test_simple_bulk_with_detailed_stats_body_as_list(self, es): } } ] - } + }) bulk = runner.BulkIndex() bulk_params = { @@ -707,7 +812,7 @@ def test_simple_bulk_with_detailed_stats_body_as_list(self, es): "index": "test" } - result = bulk(es, bulk_params) + result = await bulk(es, bulk_params) self.assertEqual("test", result["index"]) self.assertEqual(30, result["took"]) @@ -740,13 +845,14 @@ def test_simple_bulk_with_detailed_stats_body_as_list(self, es): es.bulk.assert_called_with(body=bulk_params["body"], params={}) - es.bulk.return_value.pop("ingest_took") - result = bulk(es, bulk_params) + es.bulk.return_value.result().pop("ingest_took") + result = await bulk(es, bulk_params) self.assertNotIn("ingest_took", result) @mock.patch("elasticsearch.Elasticsearch") - def test_simple_bulk_with_detailed_stats_body_as_unrecognized_type(self, es): - es.bulk.return_value = { + @run_async + async def test_simple_bulk_with_detailed_stats_body_as_unrecognized_type(self, es): + es.bulk.return_value = as_future({ "took": 30, "ingest_took": 20, "errors": False, @@ -769,7 +875,7 @@ def test_simple_bulk_with_detailed_stats_body_as_unrecognized_type(self, es): } } ] - } + }) bulk = runner.BulkIndex() bulk_params = { @@ -784,47 +890,59 @@ def test_simple_bulk_with_detailed_stats_body_as_unrecognized_type(self, es): } with self.assertRaisesRegex(exceptions.DataError, "bulk body is neither string nor list"): - bulk(es, bulk_params) + await bulk(es, bulk_params) es.bulk.assert_called_with(body=bulk_params["body"], params={}) class ForceMergeRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_force_merge_with_defaults(self, es): + @run_async + async def test_force_merge_with_defaults(self, es): + es.indices.forcemerge.return_value = as_future() force_merge = runner.ForceMerge() - force_merge(es, params={"index" : "_all"}) + await force_merge(es, params={"index" : "_all"}) es.indices.forcemerge.assert_called_once_with(index="_all", request_timeout=None) @mock.patch("elasticsearch.Elasticsearch") - def test_force_merge_override_request_timeout(self, es): + @run_async + async def test_force_merge_override_request_timeout(self, es): + es.indices.forcemerge.return_value = as_future() + force_merge = runner.ForceMerge() - force_merge(es, params={"index" : "_all", "request-timeout": 50000}) + await force_merge(es, params={"index" : "_all", "request-timeout": 50000}) es.indices.forcemerge.assert_called_once_with(index="_all", request_timeout=50000) @mock.patch("elasticsearch.Elasticsearch") - def test_force_merge_with_params(self, es): + @run_async + async def test_force_merge_with_params(self, es): + es.indices.forcemerge.return_value = as_future() + force_merge = runner.ForceMerge() - force_merge(es, params={"index" : "_all", "max-num-segments": 1, "request-timeout": 50000}) + await force_merge(es, params={"index" : "_all", "max-num-segments": 1, "request-timeout": 50000}) es.indices.forcemerge.assert_called_once_with(index="_all", max_num_segments=1, request_timeout=50000) @mock.patch("elasticsearch.Elasticsearch") - def test_optimize_with_defaults(self, es): - es.indices.forcemerge.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_optimize_with_defaults(self, es): + es.indices.forcemerge.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() force_merge = runner.ForceMerge() - force_merge(es, params={}) + await force_merge(es, params={}) es.transport.perform_request.assert_called_once_with("POST", "/_optimize", params={"request_timeout": None}) @mock.patch("elasticsearch.Elasticsearch") - def test_optimize_with_params(self, es): - es.indices.forcemerge.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_optimize_with_params(self, es): + es.indices.forcemerge.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() force_merge = runner.ForceMerge() - force_merge(es, params={"max-num-segments": 3, "request-timeout": 17000}) + await force_merge(es, params={"max-num-segments": 3, "request-timeout": 17000}) es.transport.perform_request.assert_called_once_with("POST", "/_optimize?max_num_segments=3", params={"request_timeout": 17000}) @@ -832,9 +950,11 @@ def test_optimize_with_params(self, es): class IndicesStatsRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_indices_stats_without_parameters(self, es): + @run_async + async def test_indices_stats_without_parameters(self, es): + es.indices.stats.return_value = as_future({}) indices_stats = runner.IndicesStats() - result = indices_stats(es, params={}) + result = await indices_stats(es, params={}) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) self.assertTrue(result["success"]) @@ -842,8 +962,9 @@ def test_indices_stats_without_parameters(self, es): es.indices.stats.assert_called_once_with(index="_all", metric="_all") @mock.patch("elasticsearch.Elasticsearch") - def test_indices_stats_with_failed_condition(self, es): - es.indices.stats.return_value = { + @run_async + async def test_indices_stats_with_failed_condition(self, es): + es.indices.stats.return_value = as_future({ "_all": { "total": { "merges": { @@ -852,11 +973,11 @@ def test_indices_stats_with_failed_condition(self, es): } } } - } + }) indices_stats = runner.IndicesStats() - result = indices_stats(es, params={ + result = await indices_stats(es, params={ "index": "logs-*", "condition": { "path": "_all.total.merges.current", @@ -875,8 +996,9 @@ def test_indices_stats_with_failed_condition(self, es): es.indices.stats.assert_called_once_with(index="logs-*", metric="_all") @mock.patch("elasticsearch.Elasticsearch") - def test_indices_stats_with_successful_condition(self, es): - es.indices.stats.return_value = { + @run_async + async def test_indices_stats_with_successful_condition(self, es): + es.indices.stats.return_value = as_future({ "_all": { "total": { "merges": { @@ -885,11 +1007,11 @@ def test_indices_stats_with_successful_condition(self, es): } } } - } + }) indices_stats = runner.IndicesStats() - result = indices_stats(es, params={ + result = await indices_stats(es, params={ "index": "logs-*", "condition": { "path": "_all.total.merges.current", @@ -908,8 +1030,9 @@ def test_indices_stats_with_successful_condition(self, es): es.indices.stats.assert_called_once_with(index="logs-*", metric="_all") @mock.patch("elasticsearch.Elasticsearch") - def test_indices_stats_with_non_existing_path(self, es): - es.indices.stats.return_value = { + @run_async + async def test_indices_stats_with_non_existing_path(self, es): + es.indices.stats.return_value = as_future({ "indices": { "total": { "docs": { @@ -917,11 +1040,11 @@ def test_indices_stats_with_non_existing_path(self, es): } } } - } + }) indices_stats = runner.IndicesStats() - result = indices_stats(es, params={ + result = await indices_stats(es, params={ "index": "logs-*", "condition": { # non-existing path @@ -943,8 +1066,9 @@ def test_indices_stats_with_non_existing_path(self, es): class QueryRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_query_match_only_request_body_defined(self, es): - es.search.return_value = { + @run_async + async def test_query_match_only_request_body_defined(self, es): + search_response = { "timed_out": False, "took": 5, "hits": { @@ -954,18 +1078,20 @@ def test_query_match_only_request_body_defined(self, es): }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) query_runner = runner.Query() params = { + "detailed-results": True, "cache": True, "body": { "query": { @@ -974,8 +1100,8 @@ def test_query_match_only_request_body_defined(self, es): } } - with query_runner: - result = query_runner(es, params) + async with query_runner: + result = await query_runner(es, params) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) @@ -990,10 +1116,12 @@ def test_query_match_only_request_body_defined(self, es): body=params["body"], params={"request_cache": "true"} ) + es.clear_scroll.assert_not_called() @mock.patch("elasticsearch.Elasticsearch") - def test_query_match_using_request_params(self, es): - es.search.return_value = { + @run_async + async def test_query_match_using_request_params(self, es): + response = { "timed_out": False, "took": 62, "hits": { @@ -1003,27 +1131,29 @@ def test_query_match_using_request_params(self, es): }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } + es.search.return_value = as_future(io.StringIO(json.dumps(response))) query_runner = runner.Query() params = { "cache": False, + "detailed-results": True, "body": None, "request-params": { "q": "user:kimchy" } } - with query_runner: - result = query_runner(es, params) + async with query_runner: + result = await query_runner(es, params) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) @@ -1041,29 +1171,84 @@ def test_query_match_using_request_params(self, es): "q": "user:kimchy" } ) + es.clear_scroll.assert_not_called() + + @mock.patch("elasticsearch.Elasticsearch") + @run_async + async def test_query_no_detailed_results(self, es): + response = { + "timed_out": False, + "took": 62, + "hits": { + "total": { + "value": 2, + "relation": "eq" + }, + "hits": [ + { + "title": "some-doc-1" + }, + { + "title": "some-doc-2" + } + + ] + } + } + es.search.return_value = as_future(io.StringIO(json.dumps(response))) + + query_runner = runner.Query() + params = { + "body": None, + "request-params": { + "q": "user:kimchy" + }, + "detailed-results": False + } + + async with query_runner: + result = await query_runner(es, params) + + self.assertEqual(1, result["weight"]) + self.assertEqual("ops", result["unit"]) + self.assertNotIn("hits", result) + self.assertNotIn("hits_relation", result) + self.assertNotIn("timed_out", result) + self.assertNotIn("took", result) + self.assertNotIn("error-type", result) + + es.search.assert_called_once_with( + index="_all", + body=params["body"], + params={"q": "user:kimchy"} + ) + es.clear_scroll.assert_not_called() @mock.patch("elasticsearch.Elasticsearch") - def test_query_hits_total_as_number(self, es): - es.search.return_value = { + @run_async + async def test_query_hits_total_as_number(self, es): + search_response = { "timed_out": False, "took": 5, "hits": { "total": 2, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) query_runner = runner.Query() params = { "cache": True, + "detailed-results": True, "body": { "query": { "match_all": {} @@ -1071,8 +1256,8 @@ def test_query_hits_total_as_number(self, es): } } - with query_runner: - result = query_runner(es, params) + async with query_runner: + result = await query_runner(es, params) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) @@ -1089,10 +1274,12 @@ def test_query_hits_total_as_number(self, es): "request_cache": "true" } ) + es.clear_scroll.assert_not_called() @mock.patch("elasticsearch.Elasticsearch") - def test_query_match_all(self, es): - es.search.return_value = { + @run_async + async def test_query_match_all(self, es): + search_response = { "timed_out": False, "took": 5, "hits": { @@ -1102,19 +1289,21 @@ def test_query_match_all(self, es): }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) query_runner = runner.Query() params = { "index": "unittest", + "detailed-results": True, "cache": None, "body": { "query": { @@ -1123,8 +1312,8 @@ def test_query_match_all(self, es): } } - with query_runner: - result = query_runner(es, params) + async with query_runner: + result = await query_runner(es, params) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) @@ -1139,10 +1328,12 @@ def test_query_match_all(self, es): body=params["body"], params={} ) + es.clear_scroll.assert_not_called() @mock.patch("elasticsearch.Elasticsearch") - def test_query_match_all_doc_type_fallback(self, es): - es.transport.perform_request.return_value = { + @run_async + async def test_query_match_all_doc_type_fallback(self, es): + search_response = { "timed_out": False, "took": 5, "hits": { @@ -1152,20 +1343,23 @@ def test_query_match_all_doc_type_fallback(self, es): }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } + es.transport.perform_request.return_value = as_future(io.StringIO(json.dumps(search_response))) + query_runner = runner.Query() params = { "index": "unittest", "type": "type", + "detailed-results": True, "cache": None, "body": { "query": { @@ -1174,8 +1368,8 @@ def test_query_match_all_doc_type_fallback(self, es): } } - with query_runner: - result = query_runner(es, params) + async with query_runner: + result = await query_runner(es, params) self.assertEqual(1, result["weight"]) self.assertEqual("ops", result["unit"]) @@ -1190,31 +1384,34 @@ def test_query_match_all_doc_type_fallback(self, es): body=params["body"], params={} ) + es.clear_scroll.assert_not_called() @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_only_one_page(self, es): + @run_async + async def test_scroll_query_only_one_page(self, es): # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "took": 4, "timed_out": False, "hits": { + "total": { + "value": 2, + "relation": "eq" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } - es.transport.perform_request.side_effect = [ - # delete scroll id response - { - "acknowledged": True - } - ] + + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) + es.clear_scroll.return_value = as_future(io.StringIO('{"acknowledged": true}')) query_runner = runner.Query() @@ -1230,8 +1427,8 @@ def test_scroll_query_only_one_page(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) self.assertEqual(1, results["weight"]) self.assertEqual(1, results["pages"]) @@ -1252,31 +1449,34 @@ def test_scroll_query_only_one_page(self, es): "request_cache": "true" } ) + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_no_request_cache(self, es): + @run_async + async def test_scroll_query_no_request_cache(self, es): # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "took": 4, "timed_out": False, "hits": { + "total": { + "value": 2, + "relation": "eq" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } - es.transport.perform_request.side_effect = [ - # delete scroll id response - { - "acknowledged": True - } - ] + + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) + es.clear_scroll.return_value = as_future(io.StringIO('{"acknowledged": true}')) query_runner = runner.Query() @@ -1291,8 +1491,8 @@ def test_scroll_query_no_request_cache(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) self.assertEqual(1, results["weight"]) self.assertEqual(1, results["pages"]) @@ -1311,31 +1511,34 @@ def test_scroll_query_no_request_cache(self, es): sort='_doc', params={} ) + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_only_one_page_only_request_body_defined(self, es): + @run_async + async def test_scroll_query_only_one_page_only_request_body_defined(self, es): # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "took": 4, "timed_out": False, "hits": { + "total": { + "value": 2, + "relation": "eq" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } - es.transport.perform_request.side_effect = [ - # delete scroll id response - { - "acknowledged": True - } - ] + + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) + es.clear_scroll.return_value = as_future(io.StringIO('{"acknowledged": true}')) query_runner = runner.Query() @@ -1349,8 +1552,8 @@ def test_scroll_query_only_one_page_only_request_body_defined(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) self.assertEqual(1, results["weight"]) self.assertEqual(1, results["pages"]) @@ -1361,106 +1564,64 @@ def test_scroll_query_only_one_page_only_request_body_defined(self, es): self.assertFalse(results["timed_out"]) self.assertFalse("error-type" in results) + es.search.assert_called_once_with( + index="_all", + body=params["body"], + scroll="10s", + size=100, + sort='_doc', + params={} + ) + + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) + @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_with_explicit_number_of_pages(self, es): + @run_async + async def test_scroll_query_with_explicit_number_of_pages(self, es): # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "timed_out": False, "took": 54, "hits": { + "total": { + # includes all hits across all pages + "value": 3, + "relation": "eq" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" } ] } } - es.scroll.side_effect = [ - # page 2 - { - "_scroll_id": "some-scroll-id", - "timed_out": True, - "took": 25, - "hits": { - "hits": [ - { - "some-doc-3" - } - ] - } - }, - # delete scroll id response - { - "acknowledged": True - } - ] - - query_runner = runner.Query() - - params = { - "pages": 2, - "results-per-page": 100, - "index": "unittest", - "cache": False, - "body": { - "query": { - "match_all": {} - } - } - } - - with query_runner: - results = query_runner(es, params) - - self.assertEqual(2, results["weight"]) - self.assertEqual(2, results["pages"]) - self.assertEqual(3, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(79, results["took"]) - self.assertEqual("pages", results["unit"]) - self.assertTrue(results["timed_out"]) - self.assertFalse("error-type" in results) + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) - @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_early_termination(self, es): - # page 1 - es.search.return_value = { + # page 2 + scroll_response = { "_scroll_id": "some-scroll-id", - "timed_out": False, - "took": 53, + "timed_out": True, + "took": 25, "hits": { "hits": [ { - "some-doc-1" + "title": "some-doc-3" } ] } } - es.scroll.side_effect = [ - # page 2 has no results - { - "_scroll_id": "some-scroll-id", - "timed_out": False, - "took": 2, - "hits": { - "hits": [] - } - }, - # delete scroll id response - { - "acknowledged": True - } - ] + es.scroll.return_value = as_future(io.StringIO(json.dumps(scroll_response))) + es.clear_scroll.return_value = as_future(io.StringIO('{"acknowledged": true}')) query_runner = runner.Query() params = { - "pages": 5, - "results-per-page": 100, + "pages": 2, + "results-per-page": 2, "index": "unittest", "cache": False, "body": { @@ -1470,46 +1631,44 @@ def test_scroll_query_early_termination(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) self.assertEqual(2, results["weight"]) self.assertEqual(2, results["pages"]) - self.assertEqual(1, results["hits"]) + self.assertEqual(3, results["hits"]) self.assertEqual("eq", results["hits_relation"]) + self.assertEqual(79, results["took"]) self.assertEqual("pages", results["unit"]) - self.assertEqual(55, results["took"]) + self.assertTrue(results["timed_out"]) self.assertFalse("error-type" in results) + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) + @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_cannot_clear_scroll(self, es): + @run_async + async def test_scroll_query_cannot_clear_scroll(self, es): import elasticsearch # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "timed_out": False, "took": 53, "hits": { + "total": { + "value": 1, + "relation": "eq" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" } ] } } - es.scroll.side_effect = [ - # page 2 has no results - { - "_scroll_id": "some-scroll-id", - "timed_out": False, - "took": 2, - "hits": { - "hits": [] - } - }, - # delete scroll id raises an exception - elasticsearch.ConnectionTimeout() - ] + + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) + es.clear_scroll.return_value = as_future(exception=elasticsearch.ConnectionTimeout()) query_runner = runner.Query() @@ -1525,62 +1684,66 @@ def test_scroll_query_cannot_clear_scroll(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) - self.assertEqual(2, results["weight"]) - self.assertEqual(2, results["pages"]) + self.assertEqual(1, results["weight"]) + self.assertEqual(1, results["pages"]) self.assertEqual(1, results["hits"]) self.assertEqual("eq", results["hits_relation"]) self.assertEqual("pages", results["unit"]) - self.assertEqual(55, results["took"]) + self.assertEqual(53, results["took"]) self.assertFalse("error-type" in results) + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) + @mock.patch("elasticsearch.Elasticsearch") - def test_scroll_query_request_all_pages(self, es): + @run_async + async def test_scroll_query_request_all_pages(self, es): # page 1 - es.search.return_value = { + search_response = { "_scroll_id": "some-scroll-id", "timed_out": False, "took": 876, "hits": { + "total": { + "value": 4, + "relation": "gte" + }, "hits": [ { - "some-doc-1" + "title": "some-doc-1" }, { - "some-doc-2" + "title": "some-doc-2" }, { - "some-doc-3" + "title": "some-doc-3" }, { - "some-doc-4" + "title": "some-doc-4" } ] } } - es.scroll.side_effect = [ - # page 2 has no results - { - "_scroll_id": "some-scroll-id", - "took": 24, - "timed_out": False, - "hits": { - "hits": [] - } - }, - # delete scroll id response - { - "acknowledged": True + es.search.return_value = as_future(io.StringIO(json.dumps(search_response))) + # page 2 has no results + scroll_response = { + "_scroll_id": "some-scroll-id", + "timed_out": False, + "took": 2, + "hits": { + "hits": [] } - ] + } + es.scroll.return_value = as_future(io.StringIO(json.dumps(scroll_response))) + es.clear_scroll.return_value = as_future(io.StringIO('{"acknowledged": true}')) query_runner = runner.Query() params = { "pages": "all", - "results-per-page": 100, + "results-per-page": 4, "index": "unittest", "cache": False, "body": { @@ -1590,22 +1753,27 @@ def test_scroll_query_request_all_pages(self, es): } } - with query_runner: - results = query_runner(es, params) + async with query_runner: + results = await query_runner(es, params) self.assertEqual(2, results["weight"]) self.assertEqual(2, results["pages"]) self.assertEqual(4, results["hits"]) - self.assertEqual("eq", results["hits_relation"]) - self.assertEqual(900, results["took"]) + self.assertEqual("gte", results["hits_relation"]) + self.assertEqual(878, results["took"]) self.assertEqual("pages", results["unit"]) self.assertFalse(results["timed_out"]) self.assertFalse("error-type" in results) + es.clear_scroll.assert_called_once_with(body={"scroll_id": ["some-scroll-id"]}) + class PutPipelineRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_create_pipeline(self, es): + @run_async + async def test_create_pipeline(self, es): + es.ingest.put_pipeline.return_value = as_future() + r = runner.PutPipeline() params = { @@ -1623,12 +1791,15 @@ def test_create_pipeline(self, es): } } - r(es, params) + await r(es, params) es.ingest.put_pipeline.assert_called_once_with(id="rename", body=params["body"], master_timeout=None, timeout=None) @mock.patch("elasticsearch.Elasticsearch") - def test_param_body_mandatory(self, es): + @run_async + async def test_param_body_mandatory(self, es): + es.ingest.put_pipeline.return_value = as_future() + r = runner.PutPipeline() params = { @@ -1637,12 +1808,15 @@ def test_param_body_mandatory(self, es): with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'body'. " "Please add it to your parameter source."): - r(es, params) + await r(es, params) self.assertEqual(0, es.ingest.put_pipeline.call_count) @mock.patch("elasticsearch.Elasticsearch") - def test_param_id_mandatory(self, es): + @run_async + async def test_param_id_mandatory(self, es): + es.ingest.put_pipeline.return_value = as_future() + r = runner.PutPipeline() params = { @@ -1651,18 +1825,19 @@ def test_param_id_mandatory(self, es): with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'put-pipeline' did not provide the mandatory parameter 'id'. " "Please add it to your parameter source."): - r(es, params) + await r(es, params) self.assertEqual(0, es.ingest.put_pipeline.call_count) class ClusterHealthRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_waits_for_expected_cluster_status(self, es): - es.cluster.health.return_value = { + @run_async + async def test_waits_for_expected_cluster_status(self, es): + es.cluster.health.return_value = as_future({ "status": "green", "relocating_shards": 0 - } + }) r = runner.ClusterHealth() params = { @@ -1671,7 +1846,7 @@ def test_waits_for_expected_cluster_status(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertDictEqual({ "weight": 1, @@ -1684,11 +1859,12 @@ def test_waits_for_expected_cluster_status(self, es): es.cluster.health.assert_called_once_with(index=None, params={"wait_for_status": "green"}) @mock.patch("elasticsearch.Elasticsearch") - def test_accepts_better_cluster_status(self, es): - es.cluster.health.return_value = { + @run_async + async def test_accepts_better_cluster_status(self, es): + es.cluster.health.return_value = as_future({ "status": "green", "relocating_shards": 0 - } + }) r = runner.ClusterHealth() params = { @@ -1697,7 +1873,7 @@ def test_accepts_better_cluster_status(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertDictEqual({ "weight": 1, @@ -1710,11 +1886,12 @@ def test_accepts_better_cluster_status(self, es): es.cluster.health.assert_called_once_with(index=None, params={"wait_for_status": "yellow"}) @mock.patch("elasticsearch.Elasticsearch") - def test_rejects_relocating_shards(self, es): - es.cluster.health.return_value = { + @run_async + async def test_rejects_relocating_shards(self, es): + es.cluster.health.return_value = as_future({ "status": "yellow", "relocating_shards": 3 - } + }) r = runner.ClusterHealth() params = { @@ -1725,7 +1902,7 @@ def test_rejects_relocating_shards(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertDictEqual({ "weight": 1, @@ -1739,11 +1916,12 @@ def test_rejects_relocating_shards(self, es): params={"wait_for_status": "red", "wait_for_no_relocating_shards": True}) @mock.patch("elasticsearch.Elasticsearch") - def test_rejects_unknown_cluster_status(self, es): - es.cluster.health.return_value = { + @run_async + async def test_rejects_unknown_cluster_status(self, es): + es.cluster.health.return_value = as_future({ "status": None, "relocating_shards": 0 - } + }) r = runner.ClusterHealth() params = { @@ -1752,7 +1930,7 @@ def test_rejects_unknown_cluster_status(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertDictEqual({ "weight": 1, @@ -1767,7 +1945,10 @@ def test_rejects_unknown_cluster_status(self, es): class CreateIndexRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_creates_multiple_indices(self, es): + @run_async + async def test_creates_multiple_indices(self, es): + es.indices.create.return_value = as_future() + r = runner.CreateIndex() request_params = { @@ -1782,7 +1963,7 @@ def test_creates_multiple_indices(self, es): "request-params": request_params } - result = r(es, params) + result = await r(es, params) self.assertEqual((2, "ops"), result) @@ -1792,22 +1973,27 @@ def test_creates_multiple_indices(self, es): ]) @mock.patch("elasticsearch.Elasticsearch") - def test_param_indices_mandatory(self, es): + @run_async + async def test_param_indices_mandatory(self, es): + es.indices.create.return_value = as_future() + r = runner.CreateIndex() params = {} with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'create-index' did not provide the mandatory parameter 'indices'. " "Please add it to your parameter source."): - r(es, params) + await r(es, params) self.assertEqual(0, es.indices.create.call_count) class DeleteIndexRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_deletes_existing_indices(self, es): - es.indices.exists.side_effect = [False, True] + @run_async + async def test_deletes_existing_indices(self, es): + es.indices.exists.side_effect = [as_future(False), as_future(True)] + es.indices.delete.return_value = as_future() r = runner.DeleteIndex() @@ -1816,14 +2002,17 @@ def test_deletes_existing_indices(self, es): "only-if-exists": True } - result = r(es, params) + result = await r(es, params) self.assertEqual((1, "ops"), result) es.indices.delete.assert_called_once_with(index="indexB", params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_deletes_all_indices(self, es): + @run_async + async def test_deletes_all_indices(self, es): + es.indices.delete.return_value = as_future() + r = runner.DeleteIndex() params = { @@ -1835,7 +2024,7 @@ def test_deletes_all_indices(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertEqual((2, "ops"), result) @@ -1848,7 +2037,10 @@ def test_deletes_all_indices(self, es): class CreateIndexTemplateRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_create_index_templates(self, es): + @run_async + async def test_create_index_templates(self, es): + es.indices.put_template.return_value = as_future() + r = runner.CreateIndexTemplate() params = { @@ -1862,7 +2054,7 @@ def test_create_index_templates(self, es): } } - result = r(es, params) + result = await r(es, params) self.assertEqual((2, "ops"), result) @@ -1872,21 +2064,28 @@ def test_create_index_templates(self, es): ]) @mock.patch("elasticsearch.Elasticsearch") - def test_param_templates_mandatory(self, es): + @run_async + async def test_param_templates_mandatory(self, es): + es.indices.put_template.return_value = as_future() + r = runner.CreateIndexTemplate() params = {} with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'create-index-template' did not provide the mandatory parameter " "'templates'. Please add it to your parameter source."): - r(es, params) + await r(es, params) self.assertEqual(0, es.indices.put_template.call_count) class DeleteIndexTemplateRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_deletes_all_index_templates(self, es): + @run_async + async def test_deletes_all_index_templates(self, es): + es.indices.delete_template.return_value = as_future() + es.indices.delete.return_value = as_future() + r = runner.DeleteIndexTemplate() params = { @@ -1898,7 +2097,7 @@ def test_deletes_all_index_templates(self, es): "timeout": 60 } } - result = r(es, params) + result = await r(es, params) # 2 times delete index template, one time delete matching indices self.assertEqual((3, "ops"), result) @@ -1910,8 +2109,10 @@ def test_deletes_all_index_templates(self, es): es.indices.delete.assert_called_once_with(index="logs-*") @mock.patch("elasticsearch.Elasticsearch") - def test_deletes_only_existing_index_templates(self, es): - es.indices.exists_template.side_effect = [False, True] + @run_async + async def test_deletes_only_existing_index_templates(self, es): + es.indices.exists_template.side_effect = [as_future(False), as_future(True)] + es.indices.delete_template.return_value = as_future() r = runner.DeleteIndexTemplate() @@ -1926,7 +2127,7 @@ def test_deletes_only_existing_index_templates(self, es): }, "only-if-exists": True } - result = r(es, params) + result = await r(es, params) # 2 times delete index template, one time delete matching indices self.assertEqual((1, "ops"), result) @@ -1936,21 +2137,25 @@ def test_deletes_only_existing_index_templates(self, es): self.assertEqual(0, es.indices.delete.call_count) @mock.patch("elasticsearch.Elasticsearch") - def test_param_templates_mandatory(self, es): + @run_async + async def test_param_templates_mandatory(self, es): r = runner.DeleteIndexTemplate() params = {} with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'delete-index-template' did not provide the mandatory parameter " "'templates'. Please add it to your parameter source."): - r(es, params) + await r(es, params) self.assertEqual(0, es.indices.delete_template.call_count) class CreateMlDatafeedTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_create_ml_datafeed(self, es): + @run_async + async def test_create_ml_datafeed(self, es): + es.xpack.ml.put_datafeed.return_value = as_future() + params = { "datafeed-id": "some-data-feed", "body": { @@ -1960,13 +2165,15 @@ def test_create_ml_datafeed(self, es): } r = runner.CreateMlDatafeed() - r(es, params) + await r(es, params) es.xpack.ml.put_datafeed.assert_called_once_with(datafeed_id=params["datafeed-id"], body=params["body"]) @mock.patch("elasticsearch.Elasticsearch") - def test_create_ml_datafeed_fallback(self, es): - es.xpack.ml.put_datafeed.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_create_ml_datafeed_fallback(self, es): + es.xpack.ml.put_datafeed.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() datafeed_id = "some-data-feed" body = { "job_id": "total-requests", @@ -1978,7 +2185,7 @@ def test_create_ml_datafeed_fallback(self, es): } r = runner.CreateMlDatafeed() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("PUT", "/_xpack/ml/datafeeds/%s" % datafeed_id, @@ -1988,27 +2195,32 @@ def test_create_ml_datafeed_fallback(self, es): class DeleteMlDatafeedTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_delete_ml_datafeed(self, es): + @run_async + async def test_delete_ml_datafeed(self, es): + es.xpack.ml.delete_datafeed.return_value = as_future() + datafeed_id = "some-data-feed" params = { "datafeed-id": datafeed_id } r = runner.DeleteMlDatafeed() - r(es, params) + await r(es, params) es.xpack.ml.delete_datafeed.assert_called_once_with(datafeed_id=datafeed_id, force=False, ignore=[404]) @mock.patch("elasticsearch.Elasticsearch") - def test_delete_ml_datafeed_fallback(self, es): - es.xpack.ml.delete_datafeed.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_delete_ml_datafeed_fallback(self, es): + es.xpack.ml.delete_datafeed.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() datafeed_id = "some-data-feed" params = { "datafeed-id": datafeed_id, } r = runner.DeleteMlDatafeed() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("DELETE", "/_xpack/ml/datafeeds/%s" % datafeed_id, @@ -2017,7 +2229,9 @@ def test_delete_ml_datafeed_fallback(self, es): class StartMlDatafeedTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_start_ml_datafeed_with_body(self, es): + @run_async + async def test_start_ml_datafeed_with_body(self, es): + es.xpack.ml.start_datafeed.return_value = as_future() params = { "datafeed-id": "some-data-feed", "body": { @@ -2026,7 +2240,7 @@ def test_start_ml_datafeed_with_body(self, es): } r = runner.StartMlDatafeed() - r(es, params) + await r(es, params) es.xpack.ml.start_datafeed.assert_called_once_with(datafeed_id=params["datafeed-id"], body=params["body"], @@ -2035,8 +2249,10 @@ def test_start_ml_datafeed_with_body(self, es): timeout=None) @mock.patch("elasticsearch.Elasticsearch") - def test_start_ml_datafeed_with_body_fallback(self, es): - es.xpack.ml.start_datafeed.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_start_ml_datafeed_with_body_fallback(self, es): + es.xpack.ml.start_datafeed.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() body = { "end": "now" } @@ -2046,7 +2262,7 @@ def test_start_ml_datafeed_with_body_fallback(self, es): } r = runner.StartMlDatafeed() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("POST", "/_xpack/ml/datafeeds/%s/_start" % params["datafeed-id"], @@ -2054,7 +2270,9 @@ def test_start_ml_datafeed_with_body_fallback(self, es): params=params) @mock.patch("elasticsearch.Elasticsearch") - def test_start_ml_datafeed_with_params(self, es): + @run_async + async def test_start_ml_datafeed_with_params(self, es): + es.xpack.ml.start_datafeed.return_value = as_future() params = { "datafeed-id": "some-data-feed", "start": "2017-01-01T01:00:00Z", @@ -2063,7 +2281,7 @@ def test_start_ml_datafeed_with_params(self, es): } r = runner.StartMlDatafeed() - r(es, params) + await r(es, params) es.xpack.ml.start_datafeed.assert_called_once_with(datafeed_id=params["datafeed-id"], body=None, @@ -2072,11 +2290,11 @@ def test_start_ml_datafeed_with_params(self, es): timeout=params["timeout"]) -class StopMlDatafeedTests: +class StopMlDatafeedTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - @pytest.mark.parametrize("seed", range(20)) - def test_stop_ml_datafeed(self, es, seed): - random.seed(seed) + @run_async + async def test_stop_ml_datafeed(self, es): + es.xpack.ml.stop_datafeed.return_value = as_future() params = { "datafeed-id": "some-data-feed", "force": random.choice([False, True]), @@ -2084,17 +2302,18 @@ def test_stop_ml_datafeed(self, es, seed): } r = runner.StopMlDatafeed() - r(es, params) + await r(es, params) es.xpack.ml.stop_datafeed.assert_called_once_with(datafeed_id=params["datafeed-id"], force=params["force"], timeout=params["timeout"]) @mock.patch("elasticsearch.Elasticsearch") - @pytest.mark.parametrize("seed", range(20)) - def test_stop_ml_datafeed_fallback(self, es, seed): - random.seed(seed) - es.xpack.ml.stop_datafeed.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_stop_ml_datafeed_fallback(self, es): + es.xpack.ml.stop_datafeed.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() + params = { "datafeed-id": "some-data-feed", "force": random.choice([False, True]), @@ -2102,7 +2321,7 @@ def test_stop_ml_datafeed_fallback(self, es, seed): } r = runner.StopMlDatafeed() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("POST", "/_xpack/ml/datafeeds/%s/_stop" % params["datafeed-id"], @@ -2111,7 +2330,10 @@ def test_stop_ml_datafeed_fallback(self, es, seed): class CreateMlJobTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_create_ml_job(self, es): + @run_async + async def test_create_ml_job(self, es): + es.xpack.ml.put_job.return_value = as_future() + params = { "job-id": "an-ml-job", "body": { @@ -2134,13 +2356,16 @@ def test_create_ml_job(self, es): } r = runner.CreateMlJob() - r(es, params) + await r(es, params) es.xpack.ml.put_job.assert_called_once_with(job_id=params["job-id"], body=params["body"]) @mock.patch("elasticsearch.Elasticsearch") - def test_create_ml_job_fallback(self, es): - es.xpack.ml.put_job.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_create_ml_job_fallback(self, es): + es.xpack.ml.put_job.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() + body = { "description": "Total sum of requests", "analysis_config": { @@ -2164,7 +2389,7 @@ def test_create_ml_job_fallback(self, es): } r = runner.CreateMlJob() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("PUT", "/_xpack/ml/anomaly_detectors/%s" % params["job-id"], @@ -2174,20 +2399,25 @@ def test_create_ml_job_fallback(self, es): class DeleteMlJobTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_delete_ml_job(self, es): + @run_async + async def test_delete_ml_job(self, es): + es.xpack.ml.delete_job.return_value = as_future() + job_id = "an-ml-job" params = { "job-id": job_id } r = runner.DeleteMlJob() - r(es, params) + await r(es, params) es.xpack.ml.delete_job.assert_called_once_with(job_id=job_id, force=False, ignore=[404]) @mock.patch("elasticsearch.Elasticsearch") - def test_delete_ml_job_fallback(self, es): - es.xpack.ml.delete_job.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_delete_ml_job_fallback(self, es): + es.xpack.ml.delete_job.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() job_id = "an-ml-job" params = { @@ -2195,7 +2425,7 @@ def test_delete_ml_job_fallback(self, es): } r = runner.DeleteMlJob() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("DELETE", "/_xpack/ml/anomaly_detectors/%s" % params["job-id"], @@ -2204,20 +2434,25 @@ def test_delete_ml_job_fallback(self, es): class OpenMlJobTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_open_ml_job(self, es): + @run_async + async def test_open_ml_job(self, es): + es.xpack.ml.open_job.return_value = as_future() + job_id = "an-ml-job" params = { "job-id": job_id } r = runner.OpenMlJob() - r(es, params) + await r(es, params) es.xpack.ml.open_job.assert_called_once_with(job_id=job_id) @mock.patch("elasticsearch.Elasticsearch") - def test_open_ml_job_fallback(self, es): - es.xpack.ml.open_job.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_open_ml_job_fallback(self, es): + es.xpack.ml.open_job.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() job_id = "an-ml-job" params = { @@ -2225,18 +2460,18 @@ def test_open_ml_job_fallback(self, es): } r = runner.OpenMlJob() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("POST", "/_xpack/ml/anomaly_detectors/%s/_open" % params["job-id"], params=params) -class CloseMlJobTests: +class CloseMlJobTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - @pytest.mark.parametrize("seed", range(20)) - def test_close_ml_job(self, es, seed): - random.seed(seed) + @run_async + async def test_close_ml_job(self, es): + es.xpack.ml.close_job.return_value = as_future() params = { "job-id": "an-ml-job", "force": random.choice([False, True]), @@ -2244,15 +2479,15 @@ def test_close_ml_job(self, es, seed): } r = runner.CloseMlJob() - r(es, params) + await r(es, params) es.xpack.ml.close_job.assert_called_once_with(job_id=params["job-id"], force=params["force"], timeout=params["timeout"]) @mock.patch("elasticsearch.Elasticsearch") - @pytest.mark.parametrize("seed", range(20)) - def test_close_ml_job_fallback(self, es, seed): - random.seed(seed) - es.xpack.ml.close_job.side_effect = elasticsearch.TransportError(400, "Bad Request") + @run_async + async def test_close_ml_job_fallback(self, es): + es.xpack.ml.close_job.side_effect = as_future(exception=elasticsearch.TransportError(400, "Bad Request")) + es.transport.perform_request.return_value = as_future() params = { "job-id": "an-ml-job", @@ -2261,7 +2496,7 @@ def test_close_ml_job_fallback(self, es, seed): } r = runner.CloseMlJob() - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with("POST", "/_xpack/ml/anomaly_detectors/%s/_close" % params["job-id"], @@ -2270,13 +2505,15 @@ def test_close_ml_job_fallback(self, es, seed): class RawRequestRunnerTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_issue_request_with_defaults(self, es): + @run_async + async def test_issue_request_with_defaults(self, es): + es.transport.perform_request.return_value = as_future() r = runner.RawRequest() params = { "path": "/_cat/count" } - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with(method="GET", url="/_cat/count", @@ -2285,7 +2522,9 @@ def test_issue_request_with_defaults(self, es): params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_issue_delete_index(self, es): + @run_async + async def test_issue_delete_index(self, es): + es.transport.perform_request.return_value = as_future() r = runner.RawRequest() params = { @@ -2296,7 +2535,7 @@ def test_issue_delete_index(self, es): "pretty": "true" } } - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with(method="DELETE", url="/twitter", @@ -2305,7 +2544,9 @@ def test_issue_delete_index(self, es): params={"ignore": [400, 404], "pretty": "true"}) @mock.patch("elasticsearch.Elasticsearch") - def test_issue_create_index(self, es): + @run_async + async def test_issue_create_index(self, es): + es.transport.perform_request.return_value = as_future() r = runner.RawRequest() params = { @@ -2319,7 +2560,7 @@ def test_issue_create_index(self, es): } } } - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with(method="POST", url="/twitter", @@ -2334,7 +2575,9 @@ def test_issue_create_index(self, es): params={}) @mock.patch("elasticsearch.Elasticsearch") - def test_issue_msearch(self, es): + @run_async + async def test_issue_msearch(self, es): + es.transport.perform_request.return_value = as_future() r = runner.RawRequest() params = { @@ -2349,7 +2592,7 @@ def test_issue_msearch(self, es): {"query": {"match_all": {}}} ] } - r(es, params) + await r(es, params) es.transport.perform_request.assert_called_once_with(method="GET", url="/_msearch", @@ -2366,23 +2609,25 @@ def test_issue_msearch(self, es): class SleepTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests - @mock.patch("time.sleep") - def test_missing_parameter(self, sleep, es): + @mock.patch("asyncio.sleep", return_value=as_future()) + @run_async + async def test_missing_parameter(self, sleep, es): r = runner.Sleep() with self.assertRaisesRegex(exceptions.DataError, "Parameter source for operation 'sleep' did not provide the mandatory parameter " "'duration'. Please add it to your parameter source."): - r(es, params={}) + await r(es, params={}) self.assertEqual(0, es.call_count) self.assertEqual(0, sleep.call_count) @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests - @mock.patch("time.sleep") - def test_sleep(self, sleep, es): + @mock.patch("asyncio.sleep", return_value=as_future()) + @run_async + async def test_sleep(self, sleep, es): r = runner.Sleep() - r(es, params={"duration": 4.3}) + await r(es, params={"duration": 4.3}) self.assertEqual(0, es.call_count) sleep.assert_called_once_with(4.3) @@ -2390,20 +2635,24 @@ def test_sleep(self, sleep, es): class DeleteSnapshotRepositoryTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_delete_snapshot_repository(self, es): + @run_async + async def test_delete_snapshot_repository(self, es): + es.snapshot.delete_repository.return_value = as_future() params = { "repository": "backups" } r = runner.DeleteSnapshotRepository() - r(es, params) + await r(es, params) es.snapshot.delete_repository.assert_called_once_with(repository="backups") class CreateSnapshotRepositoryTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_create_snapshot_repository(self, es): + @run_async + async def test_create_snapshot_repository(self, es): + es.snapshot.create_repository.return_value = as_future() params = { "repository": "backups", "body": { @@ -2415,7 +2664,7 @@ def test_create_snapshot_repository(self, es): } r = runner.CreateSnapshotRepository() - r(es, params) + await r(es, params) es.snapshot.create_repository.assert_called_once_with(repository="backups", body={ @@ -2429,7 +2678,10 @@ def test_create_snapshot_repository(self, es): class RestoreSnapshotTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_restore_snapshot(self, es): + @run_async + async def test_restore_snapshot(self, es): + es.snapshot.restore.return_value = as_future() + params = { "repository": "backups", "snapshot": "snapshot-001", @@ -2440,7 +2692,7 @@ def test_restore_snapshot(self, es): } r = runner.RestoreSnapshot() - r(es, params) + await r(es, params) es.snapshot.restore.assert_called_once_with(repository="backups", snapshot="snapshot-001", @@ -2449,7 +2701,9 @@ def test_restore_snapshot(self, es): params={"request_timeout": 7200}) @mock.patch("elasticsearch.Elasticsearch") - def test_restore_snapshot_with_body(self, es): + @run_async + async def test_restore_snapshot_with_body(self, es): + es.snapshot.restore.return_value = as_future() params = { "repository": "backups", "snapshot": "snapshot-001", @@ -2467,7 +2721,7 @@ def test_restore_snapshot_with_body(self, es): } r = runner.RestoreSnapshot() - r(es, params) + await r(es, params) es.snapshot.restore.assert_called_once_with(repository="backups", snapshot="snapshot-001", @@ -2484,15 +2738,16 @@ def test_restore_snapshot_with_body(self, es): class IndicesRecoveryTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_indices_recovery_already_finished(self, es): + @run_async + async def test_indices_recovery_already_finished(self, es): # empty response - es.indices.recovery.return_value = {} + es.indices.recovery.return_value = as_future({}) r = runner.IndicesRecovery() self.assertFalse(r.completed) self.assertEqual(r.percent_completed, 0.0) - r(es, { + await r(es, { "completion-recheck-wait-period": 0 }) @@ -2504,11 +2759,12 @@ def test_indices_recovery_already_finished(self, es): self.assertEqual(3, es.indices.recovery.call_count) @mock.patch("elasticsearch.Elasticsearch") - def test_waits_for_ongoing_indices_recovery(self, es): + @run_async + async def test_waits_for_ongoing_indices_recovery(self, es): # empty response es.indices.recovery.side_effect = [ # active recovery - { + as_future({ "index1": { "shards": [ { @@ -2535,11 +2791,11 @@ def test_waits_for_ongoing_indices_recovery(self, es): } ] } - }, + }), # completed - will be called three times - {}, - {}, - {}, + as_future({}), + as_future({}), + as_future({}), ] r = runner.IndicesRecovery() @@ -2547,7 +2803,7 @@ def test_waits_for_ongoing_indices_recovery(self, es): self.assertEqual(r.percent_completed, 0.0) while not r.completed: - recovered_bytes, unit = r(es, { + recovered_bytes, unit = await r(es, { "completion-recheck-wait-period": 0 }) if r.completed: @@ -2564,13 +2820,16 @@ def test_waits_for_ongoing_indices_recovery(self, es): class ShrinkIndexTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests - @mock.patch("time.sleep") - def test_shrink_index_with_shrink_node(self, sleep, es): + @mock.patch("asyncio.sleep", return_value=as_future()) + @run_async + async def test_shrink_index_with_shrink_node(self, sleep, es): # cluster health API - es.cluster.health.return_value = { + es.cluster.health.return_value = as_future({ "status": "green", "relocating_shards": 0 - } + }) + es.indices.put_settings.return_value = as_future() + es.indices.shrink.return_value = as_future() r = runner.ShrinkIndex() params = { @@ -2585,7 +2844,7 @@ def test_shrink_index_with_shrink_node(self, sleep, es): "shrink-node": "rally-node-0" } - r(es, params) + await r(es, params) es.indices.put_settings.assert_called_once_with(index="src", body={ @@ -2612,14 +2871,15 @@ def test_shrink_index_with_shrink_node(self, sleep, es): @mock.patch("elasticsearch.Elasticsearch") # To avoid real sleeps in unit tests - @mock.patch("time.sleep") - def test_shrink_index_derives_shrink_node(self, sleep, es): + @mock.patch("asyncio.sleep", return_value=as_future()) + @run_async + async def test_shrink_index_derives_shrink_node(self, sleep, es): # cluster health API - es.cluster.health.return_value = { + es.cluster.health.return_value = as_future({ "status": "green", "relocating_shards": 0 - } - es.nodes.info.return_value = { + }) + es.nodes.info.return_value = as_future({ "_nodes": { "total": 3, "successful": 3, @@ -2648,7 +2908,9 @@ def test_shrink_index_derives_shrink_node(self, sleep, es): ] } } - } + }) + es.indices.put_settings.return_value = as_future() + es.indices.shrink.return_value = as_future() r = runner.ShrinkIndex() params = { @@ -2662,7 +2924,7 @@ def test_shrink_index_derives_shrink_node(self, sleep, es): } } - r(es, params) + await r(es, params) es.indices.put_settings.assert_called_once_with(index="src", body={ @@ -2691,7 +2953,9 @@ def test_shrink_index_derives_shrink_node(self, sleep, es): class PutSettingsTests(TestCase): @mock.patch("elasticsearch.Elasticsearch") - def test_put_settings(self, es): + @run_async + async def test_put_settings(self, es): + es.cluster.put_settings.return_value = as_future() params = { "body": { "transient": { @@ -2701,7 +2965,7 @@ def test_put_settings(self, es): } r = runner.PutSettings() - r(es, params) + await r(es, params) es.cluster.put_settings.assert_called_once_with(body={ "transient": { @@ -2711,22 +2975,24 @@ def test_put_settings(self, es): class RetryTests(TestCase): - def test_is_transparent_on_success_when_no_retries(self): - delegate = mock.Mock() + @run_async + async def test_is_transparent_on_success_when_no_retries(self): + delegate = mock.Mock(return_value=as_future()) es = None params = { # no retries } retrier = runner.Retry(delegate) - retrier(es, params) + await retrier(es, params) delegate.assert_called_once_with(es, params) - def test_is_transparent_on_exception_when_no_retries(self): + @run_async + async def test_is_transparent_on_exception_when_no_retries(self): import elasticsearch - delegate = mock.Mock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host")) + delegate = mock.Mock(side_effect=as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host"))) es = None params = { # no retries @@ -2734,27 +3000,29 @@ def test_is_transparent_on_exception_when_no_retries(self): retrier = runner.Retry(delegate) with self.assertRaises(elasticsearch.ConnectionError): - retrier(es, params) + await retrier(es, params) delegate.assert_called_once_with(es, params) - def test_is_transparent_on_application_error_when_no_retries(self): + @run_async + async def test_is_transparent_on_application_error_when_no_retries(self): original_return_value = {"weight": 1, "unit": "ops", "success": False} - delegate = mock.Mock(return_value=original_return_value) + delegate = mock.Mock(return_value=as_future(original_return_value)) es = None params = { # no retries } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(original_return_value, result) delegate.assert_called_once_with(es, params) - def test_is_does_not_retry_on_success(self): - delegate = mock.Mock() + @run_async + async def test_is_does_not_retry_on_success(self): + delegate = mock.Mock(return_value=as_future()) es = None params = { "retries": 3, @@ -2764,14 +3032,20 @@ def test_is_does_not_retry_on_success(self): } retrier = runner.Retry(delegate) - retrier(es, params) + await retrier(es, params) delegate.assert_called_once_with(es, params) - def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): + @run_async + async def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): import elasticsearch - delegate = mock.Mock(side_effect=elasticsearch.ConnectionError("N/A", "no route to host")) + delegate = mock.Mock(side_effect=[ + as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host")), + as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host")), + as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host")), + as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host")) + ]) es = None params = { "retries": 3, @@ -2782,7 +3056,7 @@ def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): retrier = runner.Retry(delegate) with self.assertRaises(elasticsearch.ConnectionError): - retrier(es, params) + await retrier(es, params) delegate.assert_has_calls([ mock.call(es, params), @@ -2790,11 +3064,15 @@ def test_retries_on_timeout_if_wanted_and_raises_if_no_recovery(self): mock.call(es, params) ]) - def test_retries_on_timeout_if_wanted_and_returns_first_call(self): + @run_async + async def test_retries_on_timeout_if_wanted_and_returns_first_call(self): import elasticsearch failed_return_value = {"weight": 1, "unit": "ops", "success": False} - delegate = mock.Mock(side_effect=[elasticsearch.ConnectionError("N/A", "no route to host"), failed_return_value]) + delegate = mock.Mock(side_effect=[ + as_future(exception=elasticsearch.ConnectionError("N/A", "no route to host")), + as_future(failed_return_value) + ]) es = None params = { "retries": 3, @@ -2804,7 +3082,7 @@ def test_retries_on_timeout_if_wanted_and_returns_first_call(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(failed_return_value, result) delegate.assert_has_calls([ @@ -2814,20 +3092,21 @@ def test_retries_on_timeout_if_wanted_and_returns_first_call(self): mock.call(es, params) ]) - def test_retries_mixed_timeout_and_application_errors(self): + @run_async + async def test_retries_mixed_timeout_and_application_errors(self): import elasticsearch connection_error = elasticsearch.ConnectionError("N/A", "no route to host") failed_return_value = {"weight": 1, "unit": "ops", "success": False} success_return_value = {"weight": 1, "unit": "ops", "success": False} delegate = mock.Mock(side_effect=[ - connection_error, - failed_return_value, - connection_error, - connection_error, - failed_return_value, - success_return_value] - ) + as_future(exception=connection_error), + as_future(failed_return_value), + as_future(exception=connection_error), + as_future(exception=connection_error), + as_future(failed_return_value), + as_future(success_return_value) + ]) es = None params = { # we try exactly as often as there are errors to also test the semantics of "retry". @@ -2838,7 +3117,7 @@ def test_retries_mixed_timeout_and_application_errors(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(success_return_value, result) delegate.assert_has_calls([ @@ -2856,10 +3135,11 @@ def test_retries_mixed_timeout_and_application_errors(self): mock.call(es, params) ]) - def test_does_not_retry_on_timeout_if_not_wanted(self): + @run_async + async def test_does_not_retry_on_timeout_if_not_wanted(self): import elasticsearch - delegate = mock.Mock(side_effect=elasticsearch.ConnectionTimeout(408, "timed out")) + delegate = mock.Mock(side_effect=as_future(exception=elasticsearch.ConnectionTimeout(408, "timed out"))) es = None params = { "retries": 3, @@ -2870,15 +3150,19 @@ def test_does_not_retry_on_timeout_if_not_wanted(self): retrier = runner.Retry(delegate) with self.assertRaises(elasticsearch.ConnectionTimeout): - retrier(es, params) + await retrier(es, params) delegate.assert_called_once_with(es, params) - def test_retries_on_application_error_if_wanted(self): + @run_async + async def test_retries_on_application_error_if_wanted(self): failed_return_value = {"weight": 1, "unit": "ops", "success": False} success_return_value = {"weight": 1, "unit": "ops", "success": True} - delegate = mock.Mock(side_effect=[failed_return_value, success_return_value]) + delegate = mock.Mock(side_effect=[ + as_future(failed_return_value), + as_future(success_return_value) + ]) es = None params = { "retries": 3, @@ -2888,7 +3172,7 @@ def test_retries_on_application_error_if_wanted(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(success_return_value, result) @@ -2898,10 +3182,11 @@ def test_retries_on_application_error_if_wanted(self): mock.call(es, params) ]) - def test_does_not_retry_on_application_error_if_not_wanted(self): + @run_async + async def test_does_not_retry_on_application_error_if_not_wanted(self): failed_return_value = {"weight": 1, "unit": "ops", "success": False} - delegate = mock.Mock(return_value=failed_return_value) + delegate = mock.Mock(return_value=as_future(failed_return_value)) es = None params = { "retries": 3, @@ -2911,14 +3196,15 @@ def test_does_not_retry_on_application_error_if_not_wanted(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(failed_return_value, result) delegate.assert_called_once_with(es, params) - def test_assumes_success_if_runner_returns_non_dict(self): - delegate = mock.Mock(return_value=(1, "ops")) + @run_async + async def test_assumes_success_if_runner_returns_non_dict(self): + delegate = mock.Mock(return_value=as_future(result=(1, "ops"))) es = None params = { "retries": 3, @@ -2928,21 +3214,22 @@ def test_assumes_success_if_runner_returns_non_dict(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual((1, "ops"), result) delegate.assert_called_once_with(es, params) - def test_retries_until_success(self): + @run_async + async def test_retries_until_success(self): failure_count = 5 failed_return_value = {"weight": 1, "unit": "ops", "success": False} success_return_value = {"weight": 1, "unit": "ops", "success": True} responses = [] - responses += failure_count * [failed_return_value] - responses += [success_return_value] + responses += failure_count * [as_future(failed_return_value)] + responses += [as_future(success_return_value)] delegate = mock.Mock(side_effect=responses) es = None @@ -2952,7 +3239,7 @@ def test_retries_until_success(self): } retrier = runner.Retry(delegate) - result = retrier(es, params) + result = await retrier(es, params) self.assertEqual(success_return_value, result) diff --git a/tests/track/params_test.py b/tests/track/params_test.py index f182c30b0..fe1b5e04b 100644 --- a/tests/track/params_test.py +++ b/tests/track/params_test.py @@ -282,11 +282,11 @@ def idx(id): class IndexDataReaderTests(TestCase): def test_read_bulk_larger_than_number_of_docs(self): data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n', - '{"key": "value5"}\n' + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n', + b'{"key": "value5"}\n' ] bulk_size = 50 @@ -308,11 +308,11 @@ def test_read_bulk_larger_than_number_of_docs(self): def test_read_bulk_with_offset(self): data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n', - '{"key": "value5"}\n' + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n', + b'{"key": "value5"}\n' ] bulk_size = 50 @@ -334,13 +334,13 @@ def test_read_bulk_with_offset(self): def test_read_bulk_smaller_than_number_of_docs(self): data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n', - '{"key": "value5"}\n', - '{"key": "value6"}\n', - '{"key": "value7"}\n', + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n', + b'{"key": "value5"}\n', + b'{"key": "value6"}\n', + b'{"key": "value7"}\n', ] bulk_size = 3 @@ -362,13 +362,13 @@ def test_read_bulk_smaller_than_number_of_docs(self): def test_read_bulk_smaller_than_number_of_docs_and_multiple_clients(self): data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n', - '{"key": "value5"}\n', - '{"key": "value6"}\n', - '{"key": "value7"}\n', + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n', + b'{"key": "value5"}\n', + b'{"key": "value6"}\n', + b'{"key": "value7"}\n', ] bulk_size = 3 @@ -391,20 +391,20 @@ def test_read_bulk_smaller_than_number_of_docs_and_multiple_clients(self): def test_read_bulks_and_assume_metadata_line_in_source_file(self): data = [ - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value1"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value2"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value3"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value4"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value5"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value6"}\n', - '{"index": {"_index": "test_index", "_type": "test_type"}\n', - '{"key": "value7"}\n' + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value1"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value2"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value3"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value4"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value5"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value6"}\n', + b'{"index": {"_index": "test_index", "_type": "test_type"}\n', + b'{"key": "value7"}\n' ] bulk_size = 3 @@ -439,11 +439,11 @@ def test_read_bulk_with_id_conflicts(self): 2]) data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n', - '{"key": "value5"}\n' + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n', + b'{"key": "value5"}\n' ] bulk_size = 2 @@ -471,24 +471,24 @@ def test_read_bulk_with_id_conflicts(self): bulks.append(bulk) self.assertEqual([ - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + - '{"key": "value1"}\n' + - '{"update": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - '{"doc":{"key": "value2"}}\n', - '{"update": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + - '{"doc":{"key": "value3"}}\n' + - '{"update": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + - '{"doc":{"key": "value4"}}\n', - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - '{"key": "value5"}\n' + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + + b'{"key": "value1"}\n' + + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + + b'{"doc":{"key": "value2"}}\n', + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + + b'{"doc":{"key": "value3"}}\n' + + b'{"update": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + + b'{"doc":{"key": "value4"}}\n', + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + + b'{"key": "value5"}\n' ], bulks) def test_read_bulk_with_external_id_and_zero_conflict_probability(self): data = [ - '{"key": "value1"}\n', - '{"key": "value2"}\n', - '{"key": "value3"}\n', - '{"key": "value4"}\n' + b'{"key": "value1"}\n', + b'{"key": "value2"}\n', + b'{"key": "value3"}\n', + b'{"key": "value4"}\n' ] bulk_size = 2 @@ -513,15 +513,15 @@ def test_read_bulk_with_external_id_and_zero_conflict_probability(self): bulks.append(bulk) self.assertEqual([ - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + - '{"key": "value1"}\n' + - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + - '{"key": "value2"}\n', - - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + - '{"key": "value3"}\n' + - '{"index": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + - '{"key": "value4"}\n' + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "100"}}\n' + + b'{"key": "value1"}\n' + + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "200"}}\n' + + b'{"key": "value2"}\n', + + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "300"}}\n' + + b'{"key": "value3"}\n' + + b'{"index": {"_index": "test_index", "_type": "test_type", "_id": "400"}}\n' + + b'{"key": "value4"}\n' ], bulks) def assert_bulks_sized(self, reader, expected_bulk_sizes, expected_line_sizes): @@ -531,7 +531,7 @@ def assert_bulks_sized(self, reader, expected_bulk_sizes, expected_line_sizes): for index, type, batch in reader: for bulk_size, bulk in batch: self.assertEqual(expected_bulk_sizes[bulk_index], bulk_size, msg="bulk size") - self.assertEqual(expected_line_sizes[bulk_index], bulk.count("\n")) + self.assertEqual(expected_line_sizes[bulk_index], bulk.count(b"\n")) bulk_index += 1 self.assertEqual(len(expected_bulk_sizes), bulk_index, "Not all bulk sizes have been checked")