diff --git a/batch/test/conftest.py b/batch/test/conftest.py index fc04bda3222..0d935045c08 100644 --- a/batch/test/conftest.py +++ b/batch/test/conftest.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import logging import os @@ -9,6 +10,15 @@ log = logging.getLogger(__name__) +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + @pytest.fixture(autouse=True) def log_before_after(): log.info('starting test') diff --git a/batch/test/failure_injecting_client_session.py b/batch/test/failure_injecting_client_session.py index e1bb25523d6..e5cc6b5c5ad 100644 --- a/batch/test/failure_injecting_client_session.py +++ b/batch/test/failure_injecting_client_session.py @@ -1,7 +1,6 @@ import aiohttp from hailtop import httpx -from hailtop.utils import async_to_blocking class FailureInjectingClientSession(httpx.ClientSession): @@ -10,11 +9,11 @@ def __init__(self, should_fail): self.should_fail = should_fail self.real_session = httpx.client_session() - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, exc_type, exc_value, traceback): - async_to_blocking(self.real_session.close()) + async def __aexit__(self, exc_type, exc_value, traceback): + await self.real_session.close() def maybe_fail(self, method, path, headers): if self.should_fail(): diff --git a/batch/test/test_batch.py b/batch/test/test_batch.py index 2a7684aea66..c026a765985 100644 --- a/batch/test/test_batch.py +++ b/batch/test/test_batch.py @@ -11,6 +11,7 @@ from hailtop.auth import hail_credentials from hailtop.batch.backend import HAIL_GENETICS_HAILTOP_IMAGE from hailtop.batch_client import BatchNotCreatedError, JobNotSubmittedError +from hailtop.batch_client.aioclient import BatchClient as AioBatchClient from hailtop.batch_client.client import Batch, BatchClient from hailtop.config import get_deploy_config from hailtop.test_utils import skip_in_azure @@ -1009,7 +1010,7 @@ def test_client_max_size(client: BatchClient): b.submit() -def test_restartable_insert(client: BatchClient): +async def test_restartable_insert(): i = 0 def every_third_time(): @@ -1019,19 +1020,19 @@ def every_third_time(): return True return False - with FailureInjectingClientSession(every_third_time) as session: - client = BatchClient('test', session=session) + async with FailureInjectingClientSession(every_third_time) as session: + client = await AioBatchClient.create('test', session=session) b = create_batch(client) for _ in range(9): b.create_job(DOCKER_ROOT_IMAGE, ['echo', 'a']) - b.submit(max_bunch_size=1) - b = client.get_batch(b.id) # get a batch untainted by the FailureInjectingClientSession - status = b.wait() - assert status['state'] == 'success', str((status, b.debug_info())) - jobs = list(b.jobs()) - assert len(jobs) == 9, str((jobs, b.debug_info())) + await b.submit(max_bunch_size=1) + b = await client.get_batch(b.id) # get a batch untainted by the FailureInjectingClientSession + status = await b.wait() + assert status['state'] == 'success', str((status, await b.debug_info())) + jobs = [x async for x in b.jobs()] + assert len(jobs) == 9, str((jobs, await b.debug_info())) def test_create_idempotence(client: BatchClient): diff --git a/hail/Makefile b/hail/Makefile index d79395060b3..30558d746b1 100644 --- a/hail/Makefile +++ b/hail/Makefile @@ -27,7 +27,8 @@ JAVAC ?= javac JAR ?= jar endif -PYTEST_TARGET ?= test/hail +PYTEST_TARGET ?= test/hail test/hailtop +PYTEST_INTER_CLOUD_TARGET ?= test/hailtop/inter_cloud # not perfect, not robust to 
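The batch/test changes above give the suite one session-scoped event loop and turn FailureInjectingClientSession into an async context manager, so test_restartable_insert now drives the async batch client directly. A minimal sketch of that usage pattern; the predicate, test name, and the failure_injecting_client_session import path are assumptions based on the batch/test layout:

    import pytest
    from hailtop.batch_client.aioclient import BatchClient as AioBatchClient
    from failure_injecting_client_session import FailureInjectingClientSession  # assumed import path

    @pytest.mark.asyncio  # needed in strict mode; implicit when asyncio_mode is "auto"
    async def test_sketch():
        # a predicate that never injects failures; the real test fails every third request
        async with FailureInjectingClientSession(lambda: False) as session:
            client = await AioBatchClient.create('test', session=session)
            # create jobs, await b.submit(max_bunch_size=1), and await the result,
            # as test_restartable_insert does above
            ...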
simdpp changes, but probably fine BUILD_DEBUG_PREFIX := build/classes/scala/debug @@ -57,6 +58,7 @@ WHEEL := build/deploy/dist/hail-$(HAIL_PIP_VERSION)-py3-none-any.whl GRADLE_ARGS += -Dscala.version=$(SCALA_VERSION) -Dspark.version=$(SPARK_VERSION) -Delasticsearch.major-version=$(ELASTIC_MAJOR_VERSION) TEST_STORAGE_URI = $(shell kubectl get secret global-config --template={{.data.test_storage_uri}} | base64 --decode) +HAIL_TEST_GCS_BUCKET = $(shell kubectl get secret global-config --template={{.data.hail_test_gcs_bucket}} | base64 --decode) GCP_PROJECT = $(shell kubectl get secret global-config --template={{.data.gcp_project}} | base64 --decode) CLOUD_HAIL_TEST_RESOURCES_PREFIX = $(TEST_STORAGE_URI)/$(shell whoami)/hail-test-resources CLOUD_HAIL_TEST_RESOURCES_DIR = $(CLOUD_HAIL_TEST_RESOURCES_PREFIX)/test/resources/ @@ -118,7 +120,7 @@ services-jvm-test: $(SCALA_BUILD_INFO) $(JAR_SOURCES) $(JAR_TEST_SOURCES) ifdef HAIL_COMPILE_NATIVES fs-jvm-test: native-lib-prebuilt endif -fs-jvm-test: $(SCALA_BUILD_INFO) $(JAR_SOURCES) $(JAR_TEST_SOURCES) upload-qob-test-resources +fs-jvm-test: $(SCALA_BUILD_INFO) $(JAR_SOURCES) $(JAR_TEST_SOURCES) upload-remote-test-resources ! [ -z $(HAIL_CLOUD) ] # call like make fs-jvm-test HAIL_CLOUD=gcp or azure ! [ -z $(NAMESPACE) ] # call like make fs-jvm-test NAMEPSPACE=default HAIL_CLOUD=$(HAIL_CLOUD) \ @@ -180,7 +182,8 @@ python-jar: $(PYTHON_JAR) .PHONY: pytest pytest: $(PYTHON_VERSION_INFO) $(INIT_SCRIPTS) pytest: python/README.md $(FAST_PYTHON_JAR) $(FAST_PYTHON_JAR_EXTRA_CLASSPATH) - cd python && $(HAIL_PYTHON3) -m pytest \ + cd python && \ + $(HAIL_PYTHON3) -m pytest \ -Werror:::hail -Werror:::hailtop -Werror::ResourceWarning \ --log-cli-level=INFO \ -s \ @@ -191,20 +194,49 @@ pytest: python/README.md $(FAST_PYTHON_JAR) $(FAST_PYTHON_JAR_EXTRA_CLASSPATH) --self-contained-html \ --html=../build/reports/pytest.html \ --timeout=120 \ + --ignore $(PYTEST_INTER_CLOUD_TARGET) \ $(PYTEST_TARGET) \ $(PYTEST_ARGS) -# NOTE: Look at upload-qob-test-resources target if test resources are missing + +# NOTE: Look at upload-remote-test-resources target if test resources are missing +.PHONY: pytest-inter-cloud +pytest-inter-cloud: $(PYTHON_VERSION_INFO) $(INIT_SCRIPTS) +pytest-inter-cloud: python/README.md $(FAST_PYTHON_JAR) $(FAST_PYTHON_JAR_EXTRA_CLASSPATH) +pytest-inter-cloud: upload-remote-test-resources + cd python && \ + HAIL_TEST_STORAGE_URI=$(TEST_STORAGE_URI) \ + HAIL_TEST_GCS_BUCKET=$(HAIL_TEST_GCS_BUCKET) \ + HAIL_TEST_S3_BUCKET=hail-test-dy5rg \ + HAIL_TEST_AZURE_ACCOUNT=hailtest \ + HAIL_TEST_AZURE_CONTAINER=hail-test-4nxei \ + $(HAIL_PYTHON3) -m pytest \ + -Werror:::hail -Werror:::hailtop -Werror::ResourceWarning \ + --log-cli-level=INFO \ + -s \ + -vv \ + -r A \ + --instafail \ + --durations=50 \ + --self-contained-html \ + --html=../build/reports/pytest.html \ + --timeout=120 \ + $(PYTEST_INTER_CLOUD_TARGET) \ + $(PYTEST_ARGS) + + +# NOTE: Look at upload-remote-test-resources target if test resources are missing .PHONY: pytest-qob -pytest-qob: upload-qob-jar upload-qob-test-resources install-editable +pytest-qob: upload-qob-jar upload-remote-test-resources install-editable ! 
[ -z $(NAMESPACE) ] # call this like: make pytest-qob NAMESPACE=default cd python && \ - HAIL_QUERY_BACKEND=batch \ - HAIL_QUERY_JAR_URL=$$(cat ../upload-qob-jar) \ - HAIL_DEFAULT_NAMESPACE=$(NAMESPACE) \ - HAIL_TEST_RESOURCES_DIR='$(CLOUD_HAIL_TEST_RESOURCES_DIR)' \ - HAIL_DOCTEST_DATA_DIR='$(HAIL_DOCTEST_DATA_DIR)' \ - $(HAIL_PYTHON3) -m pytest \ + HAIL_TEST_STORAGE_URI=$(TEST_STORAGE_URI) \ + HAIL_QUERY_BACKEND=batch \ + HAIL_QUERY_JAR_URL=$$(cat ../upload-qob-jar) \ + HAIL_DEFAULT_NAMESPACE=$(NAMESPACE) \ + HAIL_TEST_RESOURCES_DIR='$(CLOUD_HAIL_TEST_RESOURCES_DIR)' \ + HAIL_DOCTEST_DATA_DIR='$(HAIL_DOCTEST_DATA_DIR)' \ + $(HAIL_PYTHON3) -m pytest \ -Werror:::hail -Werror:::hailtop -Werror::ResourceWarning \ --log-cli-level=INFO \ -s \ @@ -312,10 +344,9 @@ upload-artifacts: $(WHEEL) # NOTE: 1-day expiration of the test bucket means that this # target must be run at least once a day. To trigger this target to re-run, -# > rm upload-qob-test-resources -upload-qob-test-resources: $(shell git ls-files src/test/resources) -upload-qob-test-resources: $(shell git ls-files python/hail/docs/data) - ! [ -z $(NAMESPACE) ] # call this like: make upload-qob-test-resources NAMESPACE=default +# > rm upload-remote-test-resources +upload-remote-test-resources: $(shell git ls-files src/test/resources) +upload-remote-test-resources: $(shell git ls-files python/hail/docs/data) gcloud storage cp -r src/test/resources/\* $(CLOUD_HAIL_TEST_RESOURCES_DIR) gcloud storage cp -r python/hail/docs/data/\* $(CLOUD_HAIL_DOCTEST_DATA_DIR) # # In Azure, use the following instead of gcloud storage cp @@ -323,7 +354,7 @@ upload-qob-test-resources: $(shell git ls-files python/hail/docs/data) # {"from":"src/test/resources","to":"$(CLOUD_HAIL_TEST_RESOURCES_DIR)"},\ # {"from":"python/hail/docs/data","to":"$(CLOUD_HAIL_DOCTEST_DATA_DIR)"}\ # ]' - # touch $@ + touch $@ # NOTE: 1-day expiration of the test bucket means that this # target must be run at least once a day if using a dev NAMESPACE. diff --git a/hail/python/dev/requirements.txt b/hail/python/dev/requirements.txt index 6e9a9df42a3..31c191e1003 100644 --- a/hail/python/dev/requirements.txt +++ b/hail/python/dev/requirements.txt @@ -11,7 +11,8 @@ pytest>=7.1.3,<8 pytest-html>=1.20.0,<2 pytest-xdist>=2.2.1,<3 pytest-instafail>=0.4.2,<1 -pytest-asyncio>=0.14.0,<1 +# https://github.com/hail-is/hail/issues/14130 +pytest-asyncio>=0.14.0,<0.23 pytest-timestamper>=0.0.9,<1 pytest-timeout>=2.1,<3 pyright>=1.1.324<1.2 diff --git a/hail/python/hailtop/aiocloud/aioaws/fs.py b/hail/python/hailtop/aiocloud/aioaws/fs.py index 8aae5b16871..5c5fa31204f 100644 --- a/hail/python/hailtop/aiocloud/aioaws/fs.py +++ b/hail/python/hailtop/aiocloud/aioaws/fs.py @@ -353,12 +353,19 @@ def __init__( ) self._s3 = boto3.client('s3', config=config) + @staticmethod + def copy_part_size(url: str) -> int: # pylint: disable=unused-argument + # Because the S3 upload_part API call requires the entire part + # be loaded into memory, use a smaller part size. 
+ return 32 * 1024 * 1024 + @staticmethod def valid_url(url: str) -> bool: return url.startswith('s3://') - def parse_url(self, url: str) -> S3AsyncFSURL: - return S3AsyncFSURL(*self.get_bucket_and_name(url)) + @staticmethod + def parse_url(url: str) -> S3AsyncFSURL: + return S3AsyncFSURL(*S3AsyncFS.get_bucket_and_name(url)) @staticmethod def get_bucket_and_name(url: str) -> Tuple[str, str]: @@ -565,8 +572,3 @@ async def remove(self, url: str) -> None: async def close(self) -> None: del self._s3 - - def copy_part_size(self, url: str) -> int: # pylint: disable=unused-argument - # Because the S3 upload_part API call requires the entire part - # be loaded into memory, use a smaller part size. - return 32 * 1024 * 1024 diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py index df642fabdff..92c9275c179 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py @@ -636,8 +636,9 @@ async def is_hot_storage(self, location: str) -> bool: def valid_url(url: str) -> bool: return url.startswith('gs://') - def parse_url(self, url: str) -> GoogleStorageAsyncFSURL: - return GoogleStorageAsyncFSURL(*self.get_bucket_and_name(url)) + @staticmethod + def parse_url(url: str) -> GoogleStorageAsyncFSURL: + return GoogleStorageAsyncFSURL(*GoogleStorageAsyncFS.get_bucket_and_name(url)) @staticmethod def get_bucket_and_name(url: str) -> Tuple[str, str]: diff --git a/hail/python/hailtop/aiotools/fs/fs.py b/hail/python/hailtop/aiotools/fs/fs.py index fef3ff9833b..d55dbb0b447 100644 --- a/hail/python/hailtop/aiotools/fs/fs.py +++ b/hail/python/hailtop/aiotools/fs/fs.py @@ -154,13 +154,20 @@ class AsyncFS(abc.ABC): def schemes(self) -> Set[str]: pass + @staticmethod + def copy_part_size(url: str) -> int: # pylint: disable=unused-argument + '''Part size when copying using multi-part uploads. The part size of + the destination filesystem is used.''' + return 128 * 1024 * 1024 + @staticmethod @abc.abstractmethod def valid_url(url: str) -> bool: pass + @staticmethod @abc.abstractmethod - def parse_url(self, url: str) -> AsyncFSURL: + def parse_url(url: str) -> AsyncFSURL: pass @abc.abstractmethod @@ -326,11 +333,6 @@ async def __aexit__( ) -> None: await self.close() - def copy_part_size(self, url: str) -> int: # pylint: disable=unused-argument - '''Part size when copying using multi-part uploads. 
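With parse_url and copy_part_size made static across the filesystem implementations in this change, a URL can be inspected without constructing, entering, or closing a filesystem object. A small sketch against the S3 class, using a made-up bucket:

    from hailtop.aiocloud.aioaws.fs import S3AsyncFS

    # No S3AsyncFS instance, credentials, or event loop is needed for either call.
    url = S3AsyncFS.parse_url('s3://example-bucket/path/to/object')  # hypothetical bucket
    part_size = S3AsyncFS.copy_part_size('s3://example-bucket/path/to/object')
    assert part_size == 32 * 1024 * 1024  # S3 holds whole parts in memory, hence the smaller size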
The part size of - the destination filesystem is used.''' - return 128 * 1024 * 1024 - T = TypeVar('T', bound=AsyncFS) diff --git a/hail/python/hailtop/aiotools/local_fs.py b/hail/python/hailtop/aiotools/local_fs.py index 576c9ca00a0..df4248def42 100644 --- a/hail/python/hailtop/aiotools/local_fs.py +++ b/hail/python/hailtop/aiotools/local_fs.py @@ -236,8 +236,9 @@ def __init__(self, thread_pool: Optional[ThreadPoolExecutor] = None, max_workers def valid_url(url: str) -> bool: return url.startswith('file://') or '://' not in url - def parse_url(self, url: str) -> LocalAsyncFSURL: - return LocalAsyncFSURL(self._get_path(url)) + @staticmethod + def parse_url(url: str) -> LocalAsyncFSURL: + return LocalAsyncFSURL(LocalAsyncFS._get_path(url)) @staticmethod def _get_path(url): diff --git a/hail/python/hailtop/aiotools/router_fs.py b/hail/python/hailtop/aiotools/router_fs.py index 3ef88285377..64df740d94a 100644 --- a/hail/python/hailtop/aiotools/router_fs.py +++ b/hail/python/hailtop/aiotools/router_fs.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, List, Set, AsyncIterator, Dict, AsyncContextManager, Callable +from typing import Any, Optional, List, Set, AsyncIterator, Dict, AsyncContextManager, Callable, Type import asyncio from ..aiocloud import aioaws, aioazure, aiogoogle @@ -30,8 +30,27 @@ def __init__( else configuration_of(ConfigVariable.GCS_BUCKET_ALLOW_LIST, None, fallback="").split(",") ) - def parse_url(self, url: str) -> AsyncFSURL: - return self._get_fs(url).parse_url(url) + @staticmethod + def copy_part_size(url: str) -> int: + klass = RouterAsyncFS._fs_class(url) + return klass.copy_part_size(url) + + @staticmethod + def parse_url(url: str) -> AsyncFSURL: + klass = RouterAsyncFS._fs_class(url) + return klass.parse_url(url) + + @staticmethod + def _fs_class(url: str) -> Type[AsyncFS]: + if LocalAsyncFS.valid_url(url): + return LocalAsyncFS + if aiogoogle.GoogleStorageAsyncFS.valid_url(url): + return aiogoogle.GoogleStorageAsyncFS + if aioazure.AzureAsyncFS.valid_url(url): + return aioazure.AzureAsyncFS + if aioaws.S3AsyncFS.valid_url(url): + return aioaws.S3AsyncFS + raise ValueError(f'no file system found for url {url}') @property def schemes(self) -> Set[str]: @@ -46,7 +65,8 @@ def valid_url(url) -> bool: or aioaws.S3AsyncFS.valid_url(url) ) - def _load_fs(self, uri: str): + async def _load_fs(self, uri: str): # async ensures a running loop which is required + # by aiohttp which is used by many AsyncFSes fs: AsyncFS if LocalAsyncFS.valid_url(uri): @@ -65,72 +85,69 @@ def _load_fs(self, uri: str): self._filesystems.append(fs) return fs - def _get_fs(self, uri: str) -> AsyncFS: + async def _get_fs(self, url: str) -> AsyncFS: # async ensures a running loop which is required + # by aiohttp which is used by many AsyncFSes for fs in self._filesystems: - if fs.valid_url(uri): + if fs.valid_url(url): return fs - return self._load_fs(uri) + return await self._load_fs(url) async def open(self, url: str) -> ReadableStream: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.open(url) async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.open_from(url, start, length=length) async def create(self, url: str, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.create(url, retry_writes=retry_writes) async def multi_part_create(self, sema: asyncio.Semaphore, url: str, 
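RouterAsyncFS follows suit: _fs_class resolves the concrete filesystem class from the URL alone, so parse_url and copy_part_size stay static, while _get_fs and _load_fs become coroutines so the aiohttp-backed filesystems are only constructed under a running event loop. A short sketch of the static side, with illustrative URLs:

    from hailtop.aiotools.router_fs import RouterAsyncFS

    # Scheme-based dispatch; no filesystem object is created and no credentials are read.
    print(RouterAsyncFS.parse_url('gs://example-bucket/data.mt').path)  # hypothetical URL
    print(RouterAsyncFS.copy_part_size('s3://example-bucket/data.mt'))  # 32 MiB, S3 override
    print(RouterAsyncFS.copy_part_size('gs://example-bucket/data.mt'))  # 128 MiB, AsyncFS default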
num_parts: int) -> MultiPartCreate: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.multi_part_create(sema, url, num_parts) async def statfile(self, url: str) -> FileStatus: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.statfile(url) async def listfiles( self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True ) -> AsyncIterator[FileListEntry]: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.listfiles(url, recursive, exclude_trailing_slash_files) async def staturl(self, url: str) -> str: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.staturl(url) async def mkdir(self, url: str) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.mkdir(url) async def makedirs(self, url: str, exist_ok: bool = False) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.makedirs(url, exist_ok=exist_ok) async def isfile(self, url: str) -> bool: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.isfile(url) async def isdir(self, url: str) -> bool: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.isdir(url) async def remove(self, url: str) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.remove(url) async def rmtree( self, sema: Optional[asyncio.Semaphore], url: str, listener: Optional[Callable[[int], None]] = None ) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.rmtree(sema, url, listener) async def close(self) -> None: for fs in self._filesystems: await fs.close() - - def copy_part_size(self, url: str) -> int: - fs = self._get_fs(url) - return fs.copy_part_size(url) diff --git a/hail/python/hailtop/aiotools/validators.py b/hail/python/hailtop/aiotools/validators.py index fa2d5aa3794..e492446f7c4 100644 --- a/hail/python/hailtop/aiotools/validators.py +++ b/hail/python/hailtop/aiotools/validators.py @@ -1,6 +1,6 @@ +from hailtop.hail_event_loop import hail_event_loop from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.utils import async_to_blocking from textwrap import dedent from typing import Optional from urllib.parse import urlparse @@ -16,6 +16,14 @@ def validate_file(uri: str, router_async_fs: RouterAsyncFS, *, validate_scheme: :class:`ValueError` If one of the validation steps fails. """ + return hail_event_loop().run_until_complete( + _async_validate_file(uri, router_async_fs, validate_scheme=validate_scheme) + ) + + +async def _async_validate_file( + uri: str, router_async_fs: RouterAsyncFS, *, validate_scheme: Optional[bool] = False +) -> None: if validate_scheme: scheme = urlparse(uri).scheme if not scheme or scheme == "file": @@ -23,11 +31,11 @@ def validate_file(uri: str, router_async_fs: RouterAsyncFS, *, validate_scheme: f"Local filepath detected: '{uri}'. The Hail Batch Service does not support the use of local " "filepaths. Please specify a remote URI instead (e.g. 'gs://bucket/folder')." 
) - fs = router_async_fs._get_fs(uri) + fs = await router_async_fs._get_fs(uri) if isinstance(fs, GoogleStorageAsyncFS): location = fs.storage_location(uri) if location not in fs.allowed_storage_locations: - if not async_to_blocking(fs.is_hot_storage(location)): + if not await fs.is_hot_storage(location): raise ValueError( dedent( f"""\ diff --git a/hail/python/hailtop/fs/router_fs.py b/hail/python/hailtop/fs/router_fs.py index 4c0aa49c5f9..ca4a9512ef8 100644 --- a/hail/python/hailtop/fs/router_fs.py +++ b/hail/python/hailtop/fs/router_fs.py @@ -7,7 +7,6 @@ import fnmatch from hailtop.aiotools.fs import Copier, Transfer, FileListEntry as AIOFileListEntry, ReadableStream, WritableStream -from hailtop.aiotools.local_fs import LocalAsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.utils import bounded_gather2, async_to_blocking @@ -411,7 +410,8 @@ def supports_scheme(self, scheme: str) -> bool: return scheme in self.afs.schemes def canonicalize_path(self, path: str) -> str: - if isinstance(self.afs._get_fs(path), LocalAsyncFS): + url = self.afs.parse_url(path) + if url.scheme == 'file': if path.startswith('file:'): return 'file:' + os.path.realpath(path[5:]) return 'file:' + os.path.realpath(path) diff --git a/hail/python/hailtop/hail_event_loop.py b/hail/python/hailtop/hail_event_loop.py index 229232604eb..eaf010291e3 100644 --- a/hail/python/hailtop/hail_event_loop.py +++ b/hail/python/hailtop/hail_event_loop.py @@ -2,15 +2,17 @@ import nest_asyncio -def hail_event_loop(): +def hail_event_loop() -> asyncio.AbstractEventLoop: '''If a running event loop exists, use nest_asyncio to allow Hail's event loops to nest inside it. If no event loop exists, ask asyncio to get one for us. ''' try: - asyncio.get_running_loop() - nest_asyncio.apply() - return asyncio.get_running_loop() + loop = asyncio.get_event_loop() except RuntimeError: - return asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + nest_asyncio.apply(loop) + return loop diff --git a/hail/python/hailtop/httpx.py b/hail/python/hailtop/httpx.py index 01f845940b8..41a9dea053a 100644 --- a/hail/python/hailtop/httpx.py +++ b/hail/python/hailtop/httpx.py @@ -100,6 +100,7 @@ def __init__( if timeout is None: timeout = aiohttp.ClientTimeout(total=5) + self.loop = asyncio.get_running_loop() self.raise_for_status = raise_for_status self.client_session = aiohttp.ClientSession( *args, timeout=timeout, raise_for_status=False, connector=aiohttp.TCPConnector(ssl=tls), **kwargs @@ -108,6 +109,10 @@ def __init__( def request( self, method: str, url: aiohttp.typedefs.StrOrURL, **kwargs: Any ) -> aiohttp.client._RequestContextManager: + if self.loop != asyncio.get_running_loop(): + raise ValueError( + f'ClientSession must be created and used in same loop {self.loop} != {asyncio.get_running_loop()}.' 
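hail_event_loop now always hands back a usable loop: it reuses and nest_asyncio-patches the current loop when one exists, and otherwise creates and installs a new one. That is what lets synchronous entry points such as validate_file above drive their async helpers with run_until_complete. A minimal sketch, with the coroutine invented for illustration:

    from hailtop.hail_event_loop import hail_event_loop

    async def check():  # hypothetical stand-in for a helper like _async_validate_file
        return 'ok'

    # Works from plain synchronous code (a loop is created and set if none exists);
    # nest_asyncio.apply(loop) allows run_until_complete to nest inside a running loop.
    assert hail_event_loop().run_until_complete(check()) == 'ok'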
+ ) raise_for_status = kwargs.pop('raise_for_status', self.raise_for_status) async def request_and_raise_for_status(): diff --git a/hail/python/test/hail/conftest.py b/hail/python/test/hail/conftest.py index 933b522181c..3423e6919b2 100644 --- a/hail/python/test/hail/conftest.py +++ b/hail/python/test/hail/conftest.py @@ -7,6 +7,7 @@ import pytest from pytest import StashKey, CollectReport + from hail import current_backend, init, reset_global_randomness from hail.backend.service_backend import ServiceBackend from hailtop.hail_event_loop import hail_event_loop @@ -17,7 +18,16 @@ log = logging.getLogger(__name__) -def pytest_collection_modifyitems(config, items): +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + +def pytest_collection_modifyitems(items): n_splits = int(os.environ.get('HAIL_RUN_IMAGE_SPLITS', '1')) split_index = int(os.environ.get('HAIL_RUN_IMAGE_SPLIT_INDEX', '-1')) if n_splits <= 1: @@ -34,15 +44,6 @@ def digest(s): item.add_marker(skip_this) -@pytest.fixture(scope="session", autouse=True) -def ensure_event_loop_is_initialized_in_test_thread(): - try: - asyncio.get_running_loop() - except RuntimeError as err: - assert err.args[0] == "no running event loop" - asyncio.set_event_loop(asyncio.new_event_loop()) - - @pytest.fixture(scope="session", autouse=True) def init_hail(): hl_init_for_test() diff --git a/hail/python/test/hailtop/batch/test_batch.py b/hail/python/test/hailtop/batch/test_batch.py deleted file mode 100644 index c9eeeba9fda..00000000000 --- a/hail/python/test/hailtop/batch/test_batch.py +++ /dev/null @@ -1,1519 +0,0 @@ -import asyncio -import inspect -import secrets -import unittest - -import pytest -import os -import subprocess as sp -import tempfile -from shlex import quote as shq -import uuid -import re -import orjson - -import hailtop.fs as hfs -import hailtop.batch_client.client as bc -from hailtop import pip_version -from hailtop.batch import Batch, ServiceBackend, LocalBackend, ResourceGroup -from hailtop.batch.resource import JobResourceFile -from hailtop.batch.exceptions import BatchException -from hailtop.batch.globals import arg_max -from hailtop.utils import grouped, async_to_blocking -from hailtop.config import get_remote_tmpdir, configuration_of -from hailtop.batch.utils import concatenate -from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.test_utils import skip_in_azure -from hailtop.httpx import ClientResponseError - -from configparser import ConfigParser -from hailtop.config import get_user_config, user_config -from hailtop.config.variables import ConfigVariable -from _pytest.monkeypatch import MonkeyPatch - - -DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') -PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' -HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') -REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') - - -class LocalTests(unittest.TestCase): - def batch(self, requester_pays_project=None): - return Batch(backend=LocalBackend(), requester_pays_project=requester_pays_project) - - def read(self, file): - with open(file, 'r') as f: - result = f.read().rstrip() - return result - - def assert_same_file(self, file1, file2): - assert self.read(file1).rstrip() == self.read(file2).rstrip() - - def test_read_input_and_write_output(self): - with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: 
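Both conftest.py files in this change override pytest-asyncio's event_loop fixture at session scope, and requirements.txt pins pytest-asyncio below 0.23 per the linked issue. With that override in effect, every async test in a session shares one loop; a minimal sketch, with the test name invented:

    import asyncio
    import pytest

    # The session-scoped `event_loop` fixture above supplies the loop that
    # pytest-asyncio (<0.23) uses to run async tests and fixtures.
    @pytest.mark.asyncio  # implicit when asyncio_mode is set to "auto"
    async def test_on_session_loop():  # hypothetical test
        loop = asyncio.get_running_loop()
        assert loop is not None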
- input_file.write('abc') - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - b.write_output(input, output_file.name) - b.run() - - self.assert_same_file(input_file.name, output_file.name) - - def test_read_input_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( - 'w' - ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file1, tempfile.NamedTemporaryFile( - 'w' - ) as output_file2: - - input_file1.write('abc') - input_file2.write('123') - input_file1.flush() - input_file2.flush() - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) - - b.write_output(input.in1, output_file1.name) - b.write_output(input.in2, output_file2.name) - b.run() - - self.assert_same_file(input_file1.name, output_file1.name) - self.assert_same_file(input_file2.name, output_file2.name) - - def test_write_resource_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( - 'w' - ) as input_file2, tempfile.TemporaryDirectory() as output_dir: - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) - - b.write_output(input, output_dir + '/foo') - b.run() - - self.assert_same_file(input_file1.name, output_dir + '/foo.in1') - self.assert_same_file(input_file2.name, output_dir + '/foo.in2') - - def test_single_job(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - - b = self.batch() - j = b.new_job() - j.command(f'echo "{msg}" > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_with_shell(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - - b = self.batch() - j = b.new_job(shell='/bin/bash') - j.command(f'echo "{msg}" > {j.ofile}') - - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_with_nonsense_shell(self): - b = self.batch() - j = b.new_job(shell='/bin/ajdsfoijasidojf') - j.image(DOCKER_ROOT_IMAGE) - j.command(f'echo "hello"') - self.assertRaises(Exception, b.run) - - b = self.batch() - j = b.new_job(shell='/bin/nonexistent') - j.command(f'echo "hello"') - self.assertRaises(Exception, b.run) - - def test_single_job_with_intermediate_failure(self): - b = self.batch() - j = b.new_job() - j.command(f'echoddd "hello"') - j2 = b.new_job() - j2.command(f'echo "world"') - - self.assertRaises(Exception, b.run) - - def test_single_job_w_input(self): - with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: - msg = 'abc' - input_file.write(msg) - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_w_input_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( - 'w' - ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file: - msg1 = 'abc' - msg2 = '123' - - input_file1.write(msg1) - input_file2.write(msg2) - input_file1.flush() - input_file2.flush() - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) - j = b.new_job() - j.command(f'cat {input.in1} {input.in2} > {j.ofile}') - j.command(f'cat {input}.in1 {input}.in2') - b.write_output(j.ofile, 
output_file.name) - b.run() - - assert self.read(output_file.name) == msg1 + msg2 - - def test_single_job_bad_command(self): - b = self.batch() - j = b.new_job() - j.command("foo") # this should fail! - with self.assertRaises(sp.CalledProcessError): - b.run() - - def test_declare_resource_group(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - b = self.batch() - j = b.new_job() - j.declare_resource_group(ofile={'log': "{root}.txt"}) - assert isinstance(j.ofile, ResourceGroup) - j.command(f'echo "{msg}" > {j.ofile.log}') - b.write_output(j.ofile.log, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_resource_group_get_all_inputs(self): - b = self.batch() - input = b.read_input_group(fasta="foo", idx="bar") - j = b.new_job() - j.command(f"cat {input.fasta}") - assert input.fasta in j._inputs - assert input.idx in j._inputs - - def test_resource_group_get_all_mentioned(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - assert isinstance(j.foo, ResourceGroup) - j.command(f"cat {j.foo.bed}") - assert j.foo.bed in j._mentioned - assert j.foo.bim not in j._mentioned - - def test_resource_group_get_all_mentioned_dependent_jobs(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - j.command(f"cat") - j2 = b.new_job() - j2.command(f"cat {j.foo}") - - def test_resource_group_get_all_outputs(self): - b = self.batch() - j1 = b.new_job() - j1.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - assert isinstance(j1.foo, ResourceGroup) - j1.command(f"cat {j1.foo.bed}") - j2 = b.new_job() - j2.command(f"cat {j1.foo.bed}") - - for r in [j1.foo.bed, j1.foo.bim]: - assert r in j1._internal_outputs - assert r in j2._inputs - - assert j1.foo.bed in j1._mentioned - assert j1.foo.bim not in j1._mentioned - - assert j1.foo.bed in j2._mentioned - assert j1.foo.bim not in j2._mentioned - - assert j1.foo not in j1._mentioned - - def test_multiple_isolated_jobs(self): - b = self.batch() - - output_files = [] - try: - output_files = [tempfile.NamedTemporaryFile('w') for _ in range(5)] - - for i, ofile in enumerate(output_files): - msg = f'hello world {i}' - j = b.new_job() - j.command(f'echo "{msg}" > {j.ofile}') - b.write_output(j.ofile, ofile.name) - b.run() - - for i, ofile in enumerate(output_files): - msg = f'hello world {i}' - assert self.read(ofile.name) == msg - finally: - [ofile.close() for ofile in output_files] - - def test_multiple_dependent_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - j = b.new_job() - j.command(f'echo "0" >> {j.ofile}') - - for i in range(1, 3): - j2 = b.new_job() - j2.command(f'echo "{i}" > {j2.tmp1}') - j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') - j = j2 - - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == "0\n1\n2" - - def test_select_jobs(self): - b = self.batch() - for i in range(3): - b.new_job(name=f'foo{i}') - self.assertTrue(len(b.select_jobs('foo')) == 3) - - def test_scatter_gather(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - for i in range(3): - j = b.new_job(name=f'foo{i}') - j.command(f'echo "{i}" > {j.ofile}') - - merger = b.new_job() - merger.command( - 'cat {files} > {ofile}'.format( - files=' '.join( - [ - j.ofile - for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True) # type: ignore - ] 
- ), - ofile=merger.ofile, - ) - ) - - b.write_output(merger.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == '2\n1\n0' - - def test_add_extension_job_resource_file(self): - b = self.batch() - j = b.new_job() - j.command(f'echo "hello" > {j.ofile}') - assert isinstance(j.ofile, JobResourceFile) - j.ofile.add_extension('.txt.bgz') - assert j.ofile._value - assert j.ofile._value.endswith('.txt.bgz') - - def test_add_extension_input_resource_file(self): - input_file1 = '/tmp/data/example1.txt.bgz.foo' - b = self.batch() - in1 = b.read_input(input_file1) - assert in1._value - assert in1._value.endswith('.txt.bgz.foo') - - def test_file_name_space(self): - with tempfile.NamedTemporaryFile( - 'w', prefix="some file name with (foo) spaces" - ) as input_file, tempfile.NamedTemporaryFile('w', prefix="another file name with (foo) spaces") as output_file: - - input_file.write('abc') - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - self.assert_same_file(input_file.name, output_file.name) - - def test_resource_group_mentioned(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed'}) - assert isinstance(j.foo, ResourceGroup) - j.command(f'echo "hello" > {j.foo}') - - t2 = b.new_job() - t2.command(f'echo "hello" >> {j.foo.bed}') - b.run() - - def test_envvar(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - j = b.new_job() - j.env('SOME_VARIABLE', '123abcdef') - j.command(f'echo $SOME_VARIABLE > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - assert self.read(output_file.name) == '123abcdef' - - def test_concatenate(self): - b = self.batch() - files = [] - for _ in range(10): - j = b.new_job() - j.command(f'touch {j.ofile}') - files.append(j.ofile) - concatenate(b, files, branching_factor=2) - assert len(b._jobs) == 10 + (5 + 3 + 2 + 1) - b.run() - - def test_python_job(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - head = b.new_job() - head.command(f'echo "5" > {head.r5}') - head.command(f'echo "3" > {head.r3}') - - def read(path): - with open(path, 'r') as f: - i = f.read() - return int(i) - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.r3) - r5 = middle.call(read, head.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()} > {tail.ofile}') - - b.write_output(tail.ofile, output_file.name) - b.run() - assert self.read(output_file.name) == '3\n5\n30\n{\"x\": 3, \"y\": 5}' - - def test_backend_context_manager(self): - with LocalBackend() as backend: - b = Batch(backend=backend) - b.run() - - def test_failed_jobs_dont_stop_non_dependent_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - - head2 = b.new_job() - head2.command('false') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '1' - - def test_failed_jobs_stop_child_jobs(self): - with 
tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - head2 = b.new_job() - head2.command(f'echo 2 > {head2.ofile}') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - - b.write_output(head2.ofile, output_file.name) - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '2' - - def test_failed_jobs_stop_grandchild_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - head2 = b.new_job() - head2.command(f'echo 2 > {head2.ofile}') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - - tail2 = b.new_job() - tail2.depends_on(tail) - tail2.command(f'echo foo > {tail2.ofile}') - - b.write_output(head2.ofile, output_file.name) - b.write_output(tail2.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '2' - - def test_failed_jobs_dont_stop_always_run_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - tail.always_run() - - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '1' - - -class ServiceTests(unittest.TestCase): - def setUp(self): - # https://stackoverflow.com/questions/42332030/pytest-monkeypatch-setattr-inside-of-test-class-method - self.monkeypatch = MonkeyPatch() - - self.backend = ServiceBackend() - - remote_tmpdir = get_remote_tmpdir('hailtop_test_batch_service_tests') - if not remote_tmpdir.endswith('/'): - remote_tmpdir += '/' - self.remote_tmpdir = remote_tmpdir + str(uuid.uuid4()) + '/' - - if remote_tmpdir.startswith('gs://'): - match = re.fullmatch('gs://(?P[^/]+).*', remote_tmpdir) - assert match - self.bucket = match.groupdict()['bucket_name'] - else: - assert remote_tmpdir.startswith('hail-az://') - if remote_tmpdir.startswith('hail-az://'): - match = re.fullmatch('hail-az://(?P[^/]+)/(?P[^/]+).*', remote_tmpdir) - assert match - storage_account, container_name = match.groups() - else: - assert remote_tmpdir.startswith('https://') - match = re.fullmatch( - 'https://(?P[^/]+).blob.core.windows.net/(?P[^/]+).*', - remote_tmpdir, - ) - assert match - storage_account, container_name = match.groups() - self.bucket = f'{storage_account}/{container_name}' - - self.cloud_input_dir = f'{self.remote_tmpdir}batch-tests/resources' - - token = uuid.uuid4() - self.cloud_output_path = f'/batch-tests/{token}' - self.cloud_output_dir = f'{self.remote_tmpdir}{self.cloud_output_path}' - - self.router_fs = RouterAsyncFS() - - if not self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello.txt', b'hello world') - if not self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello spaces.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello spaces.txt', b'hello') - if not self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello (foo) spaces.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello (foo) spaces.txt', b'hello') - - def tearDown(self): - self.backend.close() - - def sync_exists(self, 
url): - return async_to_blocking(self.router_fs.exists(url)) - - def sync_write(self, url, data): - return async_to_blocking(self.router_fs.write(url, data)) - - def batch(self, **kwargs): - name_of_test_method = inspect.stack()[1][3] - return Batch( - name=name_of_test_method, - backend=self.backend, - default_image=DOCKER_ROOT_IMAGE, - attributes={'foo': 'a', 'bar': 'b'}, - **kwargs, - ) - - def test_single_task_no_io(self): - b = self.batch() - j = b.new_job() - j.command('echo hello') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_input(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.command(f'cat {input}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_input_resource_group(self): - b = self.batch() - input = b.read_input_group(foo=f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.storage('10Gi') - j.command(f'cat {input.foo}') - j.command(f'cat {input}.foo') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_output(self): - b = self.batch() - j = b.new_job(attributes={'a': 'bar', 'b': 'foo'}) - j.command(f'echo hello > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_write_output(self): - b = self.batch() - j = b.new_job() - j.command(f'echo hello > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/test_single_task_output.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_resource_group(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(output={'foo': '{root}.foo'}) - assert isinstance(j.output, ResourceGroup) - j.command(f'echo "hello" > {j.output.foo}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_write_resource_group(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(output={'foo': '{root}.foo'}) - assert isinstance(j.output, ResourceGroup) - j.command(f'echo "hello" > {j.output.foo}') - b.write_output(j.output, f'{self.cloud_output_dir}/test_single_task_write_resource_group') - b.write_output(j.output.foo, f'{self.cloud_output_dir}/test_single_task_write_resource_group_file.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_multiple_dependent_tasks(self): - output_file = f'{self.cloud_output_dir}/test_multiple_dependent_tasks.txt' - b = self.batch() - j = b.new_job() - j.command(f'echo "0" >> {j.ofile}') - - for i in range(1, 3): - j2 = b.new_job() - j2.command(f'echo "{i}" > {j2.tmp1}') - j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') - j = j2 - - b.write_output(j.ofile, output_file) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_specify_cpu(self): - b = self.batch() - j = b.new_job() - j.cpu('0.5') - j.command(f'echo "hello" > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def 
test_specify_memory(self): - b = self.batch() - j = b.new_job() - j.memory('100M') - j.command(f'echo "hello" > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_scatter_gather(self): - b = self.batch() - - for i in range(3): - j = b.new_job(name=f'foo{i}') - j.command(f'echo "{i}" > {j.ofile}') - - merger = b.new_job() - merger.command( - 'cat {files} > {ofile}'.format( - files=' '.join( - [j.ofile for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True)] # type: ignore - ), - ofile=merger.ofile, - ) - ) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_file_name_space(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello (foo) spaces.txt') - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/hello (foo) spaces.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_dry_run(self): - b = self.batch() - j = b.new_job() - j.command(f'echo hello > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/test_single_job_output.txt') - b.run(dry_run=True) - - def test_verbose(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.command(f'cat {input}') - b.write_output(input, f'{self.cloud_output_dir}/hello.txt') - res = b.run(verbose=True) - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_cloudfuse_fails_with_read_write_mount_option(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/{self.bucket}', read_only=False) - - try: - b.run() - except ClientResponseError as e: - assert 'Only read-only cloudfuse requests are supported' in e.body, e.body - else: - assert False - - def test_cloudfuse_fails_with_io_mount_point(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/io', read_only=True) - - try: - b.run() - except ClientResponseError as e: - assert 'Cloudfuse requests with mount_path=/io are not supported' in e.body, e.body - else: - assert False - - def test_cloudfuse_read_only(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/{self.bucket}', read_only=True) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_cloudfuse_implicit_dirs(self): - assert self.bucket - path = self.router_fs.parse_url(f'{self.remote_tmpdir}batch-tests/resources/hello.txt').path - b = self.batch() - j = b.new_job() - j.command(f'cat /cloudfuse/{path}') - j.cloudfuse(self.bucket, f'/cloudfuse', read_only=True) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_cloudfuse_empty_string_bucket_fails(self): - assert self.bucket - b = self.batch() - j = b.new_job() - with 
self.assertRaises(BatchException): - j.cloudfuse('', '/empty_bucket') - with self.assertRaises(BatchException): - j.cloudfuse(self.bucket, '') - - def test_cloudfuse_submount_in_io_doesnt_rm_bucket(self): - assert self.bucket - b = self.batch() - j = b.new_job() - j.cloudfuse(self.bucket, '/io/cloudfuse') - j.command(f'ls /io/cloudfuse/') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello.txt') - - @skip_in_azure - def test_fuse_requester_pays(self): - assert REQUESTER_PAYS_PROJECT - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - j = b.new_job() - j.cloudfuse('hail-test-requester-pays-fds32', '/fuse-bucket') - j.command('cat /fuse-bucket/hello') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - @skip_in_azure - def test_fuse_non_requester_pays_bucket_when_requester_pays_project_specified(self): - assert REQUESTER_PAYS_PROJECT - assert self.bucket - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - j = b.new_job() - j.command(f'ls /fuse-bucket') - j.cloudfuse(self.bucket, f'/fuse-bucket', read_only=True) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - @skip_in_azure - def test_requester_pays(self): - assert REQUESTER_PAYS_PROJECT - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - input = b.read_input('gs://hail-test-requester-pays-fds32/hello') - j = b.new_job() - j.command(f'cat {input}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_benchmark_lookalike_workflow(self): - b = self.batch() - - setup_jobs = [] - for i in range(10): - j = b.new_job(f'setup_{i}').cpu(0.25) - j.command(f'echo "foo" > {j.ofile}') - setup_jobs.append(j) - - jobs = [] - for i in range(500): - j = b.new_job(f'create_file_{i}').cpu(0.25) - j.command(f'echo {setup_jobs[i % len(setup_jobs)].ofile} > {j.ofile}') - j.command(f'echo "bar" >> {j.ofile}') - jobs.append(j) - - combine = b.new_job(f'combine_output').cpu(0.25) - for _ in grouped(arg_max(), jobs): - combine.command(f'cat {" ".join(shq(j.ofile) for j in jobs)} >> {combine.ofile}') - b.write_output(combine.ofile, f'{self.cloud_output_dir}/pipeline_benchmark_test.txt') - # too slow - # assert b.run().status()['state'] == 'success' - - def test_envvar(self): - b = self.batch() - j = b.new_job() - j.env('SOME_VARIABLE', '123abcdef') - j.command('[ $SOME_VARIABLE = "123abcdef" ]') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_job_with_shell(self): - msg = 'hello world' - b = self.batch() - j = b.new_job(shell='/bin/sh') - j.command(f'echo "{msg}"') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_job_with_nonsense_shell(self): - b = self.batch() - j = b.new_job(shell='/bin/ajdsfoijasidojf') - j.command(f'echo "hello"') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_single_job_with_intermediate_failure(self): - b = self.batch() - j = b.new_job() - j.command(f'echoddd "hello"') - j2 = b.new_job() - j2.command(f'echo "world"') - - res = b.run() - res_status = 
res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_input_directory(self): - b = self.batch() - input1 = b.read_input(self.cloud_input_dir) - input2 = b.read_input(self.cloud_input_dir.rstrip('/') + '/') - j = b.new_job() - j.command(f'ls {input1}/hello.txt') - j.command(f'ls {input2}/hello.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_python_job(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.command(f'echo "5" > {head.r5}') - head.command(f'echo "3" > {head.r3}') - - def read(path): - with open(path, 'r') as f: - i = f.read() - return int(i) - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.r3) - r5 = middle.call(read, head.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) - - def test_python_job_w_resource_group_unpack_individually(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) - assert isinstance(head.count, ResourceGroup) - - head.command(f'echo "5" > {head.count.r5}') - head.command(f'echo "3" > {head.count.r3}') - - def read(path): - with open(path, 'r') as f: - r = int(f.read()) - return r - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.count.r3) - r5 = middle.call(read, head.count.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) - - def test_python_job_can_write_to_resource_path(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(path): - with open(path, 'w') as f: - f.write('foo') - - head = b.new_python_job() - head.call(write, head.ofile) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == 'foo', str(res.debug_info()) - - def test_python_job_w_resource_group_unpack_jointly(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) - assert isinstance(head.count, ResourceGroup) - - head.command(f'echo "5" > {head.count.r5}') - head.command(f'echo "3" > {head.count.r3}') - - def read_rg(root): - with open(root['r3'], 'r') as f: - r3 = int(f.read()) - 
with open(root['r5'], 'r') as f: - r5 = int(f.read()) - return (r3, r5) - - def multiply(r): - x, y = r - return x * y - - middle = b.new_python_job() - r = middle.call(read_rg, head.count) - r_mult = middle.call(multiply, r) - - tail = b.new_job() - tail.command(f'cat {r_mult.as_str()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - job_log_3 = res.get_job_log(3) - assert job_log_3['main'] == "15\n", str((job_log_3, res.debug_info())) - - def test_python_job_w_non_zero_ec(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - j = b.new_python_job() - - def error(): - raise Exception("this should fail") - - j.call(error) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_python_job_incorrect_signature(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def foo(pos_arg1, pos_arg2, *, kwarg1, kwarg2=1): - print(pos_arg1, pos_arg2, kwarg1, kwarg2) - - j = b.new_python_job() - - with pytest.raises(BatchException): - j.call(foo) - with pytest.raises(BatchException): - j.call(foo, 1) - with pytest.raises(BatchException): - j.call(foo, 1, 2) - with pytest.raises(BatchException): - j.call(foo, 1, kwarg1=2) - with pytest.raises(BatchException): - j.call(foo, 1, 2, 3) - with pytest.raises(BatchException): - j.call(foo, 1, 2, kwarg1=3, kwarg2=4, kwarg3=5) - - j.call(foo, 1, 2, kwarg1=3) - j.call(foo, 1, 2, kwarg1=3, kwarg2=4) - - # `print` doesn't have a signature but other builtins like `abs` do - j.call(print, 5) - j.call(abs, -1) - with pytest.raises(BatchException): - j.call(abs, -1, 5) - - def test_fail_fast(self): - b = self.batch(cancel_after_n_failures=1) - - j1 = b.new_job() - j1.command('false') - - j2 = b.new_job() - j2.command('sleep 300') - - res = b.run() - job_status = res.get_job(2).status() - assert job_status['state'] == 'Cancelled', str((job_status, res.debug_info())) - - def test_service_backend_remote_tempdir_with_trailing_slash(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files/') - b = Batch(backend=backend) - j1 = b.new_job() - j1.command(f'echo hello > {j1.ofile}') - j2 = b.new_job() - j2.command(f'cat {j1.ofile}') - b.run() - - def test_service_backend_remote_tempdir_with_no_trailing_slash(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - j1 = b.new_job() - j1.command(f'echo hello > {j1.ofile}') - j2 = b.new_job() - j2.command(f'cat {j1.ofile}') - b.run() - - def test_large_command(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - j1 = b.new_job() - long_str = secrets.token_urlsafe(15 * 1024) - j1.command(f'echo "{long_str}"') - b.run() - - def test_big_batch_which_uses_slow_path(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - # 8 * 256 * 1024 = 2 MiB > 1 MiB max bunch size - for _ in range(8): - j1 = b.new_job() - long_str = secrets.token_urlsafe(256 * 1024) - j1.command(f'echo "{long_str}" > /dev/null') - batch = b.run() - assert not batch._submission_info.used_fast_path - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch.debug_info())) - - def test_query_on_batch_in_batch(self): - sb = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - bb = Batch(backend=sb, 
default_python_image=HAIL_GENETICS_HAIL_IMAGE) - - tmp_ht_path = self.remote_tmpdir + '/' + secrets.token_urlsafe(32) - - def qob_in_batch(): - import hail as hl - - hl.utils.range_table(10).write(tmp_ht_path, overwrite=True) - - j = bb.new_python_job() - j.env('HAIL_QUERY_BACKEND', 'batch') - j.env('HAIL_BATCH_BILLING_PROJECT', configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, None, '')) - j.env('HAIL_BATCH_REMOTE_TMPDIR', self.remote_tmpdir) - j.call(qob_in_batch) - - bb.run() - - def test_basic_async_fun(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - - j = b.new_python_job() - j.call(asyncio.sleep, 1) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch.debug_info())) - - def test_async_fun_returns_value(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - - async def foo(i, j): - await asyncio.sleep(1) - return i * j - - j = b.new_python_job() - result = j.call(foo, 2, 3) - - j = b.new_job() - j.command(f'cat {result.as_str()}') - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - job_log_2 = batch.get_job_log(2) - assert job_log_2['main'] == "6\n", str((job_log_2, batch.debug_info())) - - def test_specify_job_region(self): - b = self.batch(cancel_after_n_failures=1) - j = b.new_job('region') - possible_regions = self.backend.supported_regions() - j.regions(possible_regions) - j.command('true') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_always_copy_output(self): - output_path = f'{self.cloud_output_dir}/test_always_copy_output.txt' - - b = self.batch() - j = b.new_job() - j.always_copy_output() - j.command(f'echo "hello" > {j.ofile} && false') - - b.write_output(j.ofile, output_path) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - b2 = self.batch() - input = b2.read_input(output_path) - file_exists_j = b2.new_job() - file_exists_j.command(f'cat {input}') - - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(1)['main'] == "hello\n", str(res.debug_info()) - - def test_no_copy_output_on_failure(self): - output_path = f'{self.cloud_output_dir}/test_no_copy_output.txt' - - b = self.batch() - j = b.new_job() - j.command(f'echo "hello" > {j.ofile} && false') - - b.write_output(j.ofile, output_path) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - b2 = self.batch() - input = b2.read_input(output_path) - file_exists_j = b2.new_job() - file_exists_j.command(f'cat {input}') - - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_update_batch(self): - b = self.batch() - j = b.new_job() - j.command('true') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - j2 = b.new_job() - j2.command('true') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_update_batch_with_dependencies(self): - b = self.batch() - j1 = b.new_job() - 
j1.command('true') - j2 = b.new_job() - j2.command('false') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - j3 = b.new_job() - j3.command('true') - j3.depends_on(j1) - - j4 = b.new_job() - j4.command('true') - j4.depends_on(j2) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - assert res.get_job(3).status()['state'] == 'Success', str((res_status, res.debug_info())) - assert res.get_job(4).status()['state'] == 'Cancelled', str((res_status, res.debug_info())) - - def test_update_batch_with_python_job_dependencies(self): - b = self.batch() - - async def foo(i, j): - await asyncio.sleep(1) - return i * j - - j1 = b.new_python_job() - j1.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - j2 = b.new_python_job() - j2.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - j3 = b.new_python_job() - j3.depends_on(j2) - j3.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - def test_update_batch_from_batch_id(self): - b = self.batch() - j = b.new_job() - j.command('true') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - b2 = Batch.from_batch_id(res.id, backend=b._backend) - j2 = b2.new_job() - j2.command('true') - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_python_job_with_kwarg(self): - def foo(*, kwarg): - return kwarg - - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - j = b.new_python_job() - r = j.call(foo, kwarg='hello world') - - output_path = f'{self.cloud_output_dir}/test_python_job_with_kwarg' - b.write_output(r.as_json(), output_path) - res = b.run() - assert isinstance(res, bc.Batch) - - assert res.status()['state'] == 'success', str((res, res.debug_info())) - with hfs.open(output_path) as f: - assert orjson.loads(f.read()) == 'hello world' - - def test_tuple_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(paths): - if not isinstance(paths, tuple): - raise ValueError('paths must be a tuple') - for i, path in enumerate(paths): - with open(path, 'w') as f: - f.write(f'{i}') - - head = b.new_python_job() - head.call(write, (head.ofile1, head.ofile2)) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) - - def test_list_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(paths): - for i, path in enumerate(paths): - with open(path, 'w') as f: - f.write(f'{i}') - - head = b.new_python_job() - head.call(write, [head.ofile1, head.ofile2]) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert 
res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) - - def test_dict_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(kwargs): - for k, v in kwargs.items(): - with open(v, 'w') as f: - f.write(k) - - head = b.new_python_job() - head.call(write, {'a': head.ofile1, 'b': head.ofile2}) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == 'ab', str(res.debug_info()) - - def test_wait_on_empty_batch_update(self): - b = self.batch() - b.run(wait=True) - b.run(wait=True) - - def test_non_spot_job(self): - b = self.batch() - j = b.new_job() - j.spot(False) - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == False - - def test_spot_unspecified_job(self): - b = self.batch() - j = b.new_job() - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == True - - def test_spot_true_job(self): - b = self.batch() - j = b.new_job() - j.spot(True) - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == True - - def test_non_spot_batch(self): - b = self.batch(default_spot=False) - j1 = b.new_job() - j1.command('echo hello') - j2 = b.new_job() - j2.command('echo hello') - j3 = b.new_job() - j3.spot(True) - j3.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == False - assert res.get_job(2).status()['spec']['resources']['preemptible'] == False - assert res.get_job(3).status()['spec']['resources']['preemptible'] == True - - def test_local_file_paths_error(self): - b = self.batch() - j = b.new_job() - for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: - with pytest.raises(ValueError) as e: - b.read_input(input) - assert str(e.value).startswith("Local filepath detected") - - @skip_in_azure - def test_validate_cloud_storage_policy(self): - # buckets do not exist (bucket names can't contain the string "google" per - # https://cloud.google.com/storage/docs/buckets) - fake_bucket1 = "google" - fake_bucket2 = "google1" - no_bucket_error = "bucket does not exist" - # bucket exists, but account does not have permissions on it - no_perms_bucket = "test" - no_perms_error = "does not have storage.buckets.get access" - # bucket exists and account has permissions, but is set to use cold storage by default - cold_bucket = "hail-test-cold-storage" - cold_error = "configured to use cold storage by default" - fake_uri1, fake_uri2, no_perms_uri, cold_uri = [ - f"gs://{bucket}/test" for bucket in [fake_bucket1, fake_bucket2, no_perms_bucket, cold_bucket] - ] - - def _test_raises(exception_type, exception_msg, func): - with pytest.raises(exception_type) as e: - func() - assert exception_msg in str(e.value) - - def _test_raises_no_bucket_error(remote_tmpdir, arg=None): - _test_raises( - ClientResponseError, - no_bucket_error, - lambda: ServiceBackend(remote_tmpdir=remote_tmpdir, gcs_bucket_allow_list=arg), - ) - - def _test_raises_cold_error(func): - 
_test_raises(ValueError, cold_error, func) - - # no configuration, nonexistent buckets error - _test_raises_no_bucket_error(fake_uri1) - _test_raises_no_bucket_error(fake_uri2) - - # no configuration, no perms bucket errors - _test_raises(ClientResponseError, no_perms_error, lambda: ServiceBackend(remote_tmpdir=no_perms_uri)) - - # no configuration, cold bucket errors - _test_raises_cold_error(lambda: ServiceBackend(remote_tmpdir=cold_uri)) - b = self.batch() - _test_raises_cold_error(lambda: b.read_input(cold_uri)) - j = b.new_job() - j.command(f"echo hello > {j.ofile}") - _test_raises_cold_error(lambda: b.write_output(j.ofile, cold_uri)) - - # hailctl config, allowlisted nonexistent buckets don't error - base_config = get_user_config() - local_config = ConfigParser() - local_config.read_dict( - { - **{ - section: {key: val for key, val in base_config[section].items()} - for section in base_config.sections() - }, - **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}}, - } - ) - - def _get_user_config(): - return local_config - - self.monkeypatch.setattr(user_config, "get_user_config", _get_user_config) - ServiceBackend(remote_tmpdir=fake_uri1) - ServiceBackend(remote_tmpdir=fake_uri2) - - # environment variable config, only allowlisted nonexistent buckets don't error - self.monkeypatch.setenv("HAIL_GCS_BUCKET_ALLOW_LIST", fake_bucket2) - _test_raises_no_bucket_error(fake_uri1) - ServiceBackend(remote_tmpdir=fake_uri2) - - # arg to constructor config, only allowlisted nonexistent buckets don't error - arg = [fake_bucket1] - ServiceBackend(remote_tmpdir=fake_uri1, gcs_bucket_allow_list=arg) - _test_raises_no_bucket_error(fake_uri2, arg) diff --git a/hail/python/test/hailtop/batch/test_batch_local_backend.py b/hail/python/test/hailtop/batch/test_batch_local_backend.py new file mode 100644 index 00000000000..28a9aa18600 --- /dev/null +++ b/hail/python/test/hailtop/batch/test_batch_local_backend.py @@ -0,0 +1,506 @@ +from typing import AsyncIterator +import os +import pytest +import subprocess as sp +import tempfile + +from hailtop import pip_version +from hailtop.batch import Batch, LocalBackend, ResourceGroup +from hailtop.batch.resource import JobResourceFile +from hailtop.batch.utils import concatenate + + +DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') +PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' +HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') +REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') + + +@pytest.fixture(scope="session") +async def backend() -> AsyncIterator[LocalBackend]: + lb = LocalBackend() + try: + yield lb + finally: + await lb.async_close() + + +@pytest.fixture +def batch(backend, requester_pays_project=None): + return Batch(backend=backend, requester_pays_project=requester_pays_project) + + +def test_read_input_and_write_output(batch): + with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: + input_file.write('abc') + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + b.write_output(input, output_file.name) + b.run() + + assert open(input_file.name).read() == open(output_file.name).read() + + +def test_read_input_group(batch): + with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file1, tempfile.NamedTemporaryFile( + 'w' + ) as output_file2: + input_file1.write('abc') + 
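The session-scoped async backend fixture above relies on pytest-asyncio driving async fixtures and tests on a shared session event loop. Below is a minimal sketch of the same pattern, assuming pytest-asyncio in auto mode; FakeBackend and test_fake_backend_is_open are illustrative names, not part of this change.

from typing import AsyncIterator
import asyncio
import pytest


class FakeBackend:
    # Hypothetical stand-in for a backend that must be closed asynchronously.
    def __init__(self) -> None:
        self.closed = False

    async def async_close(self) -> None:
        await asyncio.sleep(0)  # pretend to release connections
        self.closed = True


@pytest.fixture(scope="session")
async def fake_backend() -> AsyncIterator[FakeBackend]:
    fb = FakeBackend()
    try:
        yield fb  # shared by every test in the session
    finally:
        await fb.async_close()  # torn down once, on the session's event loop


async def test_fake_backend_is_open(fake_backend: FakeBackend):
    assert not fake_backend.closed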
input_file2.write('123') + input_file1.flush() + input_file2.flush() + + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + + b.write_output(input.in1, output_file1.name) + b.write_output(input.in2, output_file2.name) + b.run() + + assert open(input_file1.name).read() == open(output_file1.name).read() + assert open(input_file2.name).read() == open(output_file2.name).read() + + +def test_write_resource_group(batch): + with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.TemporaryDirectory() as output_dir: + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + + b.write_output(input, output_dir + '/foo') + b.run() + + assert open(input_file1.name).read() == open(output_dir + '/foo.in1').read() + assert open(input_file2.name).read() == open(output_dir + '/foo.in2').read() + + +def test_single_job(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + + b = batch + j = b.new_job() + j.command(f'printf "{msg}" > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_with_shell(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + + b = batch + j = b.new_job(shell='/bin/bash') + j.command(f'printf "{msg}" > {j.ofile}') + + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_with_nonsense_shell(batch): + b = batch + j = b.new_job(shell='/bin/ajdsfoijasidojf') + j.image(DOCKER_ROOT_IMAGE) + j.command(f'printf "hello"') + with pytest.raises(Exception): + b.run() + + b = batch + j = b.new_job(shell='/bin/nonexistent') + j.command(f'printf "hello"') + with pytest.raises(Exception): + b.run() + + +def test_single_job_with_intermediate_failure(batch): + b = batch + j = b.new_job() + j.command(f'echoddd "hello"') + j2 = b.new_job() + j2.command(f'echo "world"') + + with pytest.raises(Exception): + b.run() + + +def test_single_job_w_input(batch): + with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: + msg = 'abc' + input_file.write(msg) + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_w_input_group(batch): + with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file: + msg1 = 'abc' + msg2 = '123' + + input_file1.write(msg1) + input_file2.write(msg2) + input_file1.flush() + input_file2.flush() + + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + j = b.new_job() + j.command(f'cat {input.in1} {input.in2} > {j.ofile}') + j.command(f'cat {input}.in1 {input}.in2') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg1 + msg2 + + +def test_single_job_bad_command(batch): + b = batch + j = b.new_job() + j.command("foo") # this should fail! 
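The LocalBackend tests above all follow the same read_input -> job command -> write_output flow. Here is a self-contained sketch of that flow outside pytest, assuming hailtop is installed; file contents and names are placeholders.

import tempfile

from hailtop.batch import Batch, LocalBackend

with tempfile.NamedTemporaryFile('w') as inp, tempfile.NamedTemporaryFile('w') as out:
    inp.write('abc')
    inp.flush()

    # LocalBackend runs each job's command as a local subprocess.
    with LocalBackend() as backend:
        b = Batch(backend=backend)
        f = b.read_input(inp.name)             # localize the input file
        j = b.new_job()
        j.command(f'cat {f} > {j.ofile}')      # j.ofile is an output resource
        b.write_output(j.ofile, out.name)      # copy the job's output back out
        b.run()

    assert open(out.name).read() == 'abc'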
+ with pytest.raises(sp.CalledProcessError): + b.run() + + +def test_declare_resource_group(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + b = batch + j = b.new_job() + j.declare_resource_group(ofile={'log': "{root}.txt"}) + assert isinstance(j.ofile, ResourceGroup) + j.command(f'printf "{msg}" > {j.ofile.log}') + b.write_output(j.ofile.log, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_resource_group_get_all_inputs(batch): + b = batch + input = b.read_input_group(fasta="foo", idx="bar") + j = b.new_job() + j.command(f"cat {input.fasta}") + assert input.fasta in j._inputs + assert input.idx in j._inputs + + +def test_resource_group_get_all_mentioned(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + assert isinstance(j.foo, ResourceGroup) + j.command(f"cat {j.foo.bed}") + assert j.foo.bed in j._mentioned + assert j.foo.bim not in j._mentioned + + +def test_resource_group_get_all_mentioned_dependent_jobs(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + j.command(f"cat") + j2 = b.new_job() + j2.command(f"cat {j.foo}") + + +def test_resource_group_get_all_outputs(batch): + b = batch + j1 = b.new_job() + j1.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + assert isinstance(j1.foo, ResourceGroup) + j1.command(f"cat {j1.foo.bed}") + j2 = b.new_job() + j2.command(f"cat {j1.foo.bed}") + + for r in [j1.foo.bed, j1.foo.bim]: + assert r in j1._internal_outputs + assert r in j2._inputs + + assert j1.foo.bed in j1._mentioned + assert j1.foo.bim not in j1._mentioned + + assert j1.foo.bed in j2._mentioned + assert j1.foo.bim not in j2._mentioned + + assert j1.foo not in j1._mentioned + + +def test_multiple_isolated_jobs(batch): + b = batch + + output_files = [] + try: + output_files = [tempfile.NamedTemporaryFile('w') for _ in range(5)] + + for i, ofile in enumerate(output_files): + msg = f'hello world {i}' + j = b.new_job() + j.command(f'printf "{msg}" > {j.ofile}') + b.write_output(j.ofile, ofile.name) + b.run() + + for i, ofile in enumerate(output_files): + msg = f'hello world {i}' + assert open(ofile.name).read() == msg + finally: + [ofile.close() for ofile in output_files] + + +def test_multiple_dependent_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + j = b.new_job() + j.command(f'echo "0" >> {j.ofile}') + + for i in range(1, 3): + j2 = b.new_job() + j2.command(f'echo "{i}" > {j2.tmp1}') + j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') + j = j2 + + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == "0\n1\n2\n" + + +def test_select_jobs(batch): + b = batch + for i in range(3): + b.new_job(name=f'foo{i}') + assert len(b.select_jobs('foo')) == 3 + + +def test_scatter_gather(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + for i in range(3): + j = b.new_job(name=f'foo{i}') + j.command(f'echo "{i}" > {j.ofile}') + + merger = b.new_job() + merger.command( + 'cat {files} > {ofile}'.format( + files=' '.join( + [j.ofile for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True)] # type: ignore + ), + ofile=merger.ofile, + ) + ) + + b.write_output(merger.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == '2\n1\n0\n' + + +def test_add_extension_job_resource_file(batch): + b = batch + j = b.new_job() + j.command(f'echo 
"hello" > {j.ofile}') + assert isinstance(j.ofile, JobResourceFile) + j.ofile.add_extension('.txt.bgz') + assert j.ofile._value + assert j.ofile._value.endswith('.txt.bgz') + + +def test_add_extension_input_resource_file(batch): + input_file1 = '/tmp/data/example1.txt.bgz.foo' + b = batch + in1 = b.read_input(input_file1) + assert in1._value + assert in1._value.endswith('.txt.bgz.foo') + + +def test_file_name_space(batch): + with tempfile.NamedTemporaryFile( + 'w', prefix="some file name with (foo) spaces" + ) as input_file, tempfile.NamedTemporaryFile('w', prefix="another file name with (foo) spaces") as output_file: + input_file.write('abc') + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(input_file.name).read() == open(output_file.name).read() + + +def test_resource_group_mentioned(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed'}) + assert isinstance(j.foo, ResourceGroup) + j.command(f'echo "hello" > {j.foo}') + + t2 = b.new_job() + t2.command(f'echo "hello" >> {j.foo.bed}') + b.run() + + +def test_envvar(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + j = b.new_job() + j.env('SOME_VARIABLE', '123abcdef') + j.command(f'printf $SOME_VARIABLE > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + assert open(output_file.name).read() == '123abcdef' + + +def test_concatenate(batch): + b = batch + files = [] + for _ in range(10): + j = b.new_job() + j.command(f'touch {j.ofile}') + files.append(j.ofile) + concatenate(b, files, branching_factor=2) + assert len(b._jobs) == 10 + (5 + 3 + 2 + 1) + b.run() + + +def test_python_job(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + head = b.new_job() + head.command(f'echo "5" > {head.r5}') + head.command(f'echo "3" > {head.r3}') + + def read(path): + with open(path, 'r') as f: + i = f.read() + return int(i) + + def multiply(x, y): + return x * y + + def reformat(x, y): + return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.r3) + r5 = middle.call(read, head.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()} > {tail.ofile}') + + b.write_output(tail.ofile, output_file.name) + b.run() + assert open(output_file.name).read() == '3\n5\n30\n{\"x\": 3, \"y\": 5}\n' + + +def test_backend_context_manager(): + with LocalBackend() as backend: + b = Batch(backend=backend) + b.run() + + +def test_failed_jobs_dont_stop_non_dependent_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + + head2 = b.new_job() + head2.command('false') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '1' + + +def test_failed_jobs_stop_child_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + head2 = b.new_job() + head2.command(f'printf 2 > {head2.ofile}') + + tail = b.new_job() + tail.command(f'cat 
{head.ofile} > {tail.ofile}') + + b.write_output(head2.ofile, output_file.name) + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '2' + + +def test_failed_jobs_stop_grandchild_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + head2 = b.new_job() + head2.command(f'printf 2 > {head2.ofile}') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + + tail2 = b.new_job() + tail2.depends_on(tail) + tail2.command(f'printf foo > {tail2.ofile}') + + b.write_output(head2.ofile, output_file.name) + b.write_output(tail2.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '2' + + +def test_failed_jobs_dont_stop_always_run_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + tail.always_run() + + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '1' diff --git a/hail/python/test/hailtop/batch/test_batch_service_backend.py b/hail/python/test/hailtop/batch/test_batch_service_backend.py new file mode 100644 index 00000000000..2364dba9a4e --- /dev/null +++ b/hail/python/test/hailtop/batch/test_batch_service_backend.py @@ -0,0 +1,1181 @@ +from typing import AsyncIterator, List, Tuple +import asyncio +import inspect +import secrets + +import pytest +import os +from shlex import quote as shq +import uuid +import re +import orjson + +import hailtop.fs as hfs +import hailtop.batch_client.client as bc +from hailtop import pip_version +from hailtop.batch import Batch, ServiceBackend, ResourceGroup +from hailtop.batch.exceptions import BatchException +from hailtop.batch.globals import arg_max +from hailtop.utils import grouped, async_to_blocking, secret_alnum_string +from hailtop.config import get_remote_tmpdir, configuration_of +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.test_utils import skip_in_azure +from hailtop.httpx import ClientResponseError + +from configparser import ConfigParser +from hailtop.config import get_user_config, user_config +from hailtop.config.variables import ConfigVariable + + +DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') +PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' +HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') +REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') + + +@pytest.fixture(scope="session") +async def backend() -> AsyncIterator[ServiceBackend]: + sb = ServiceBackend() + try: + yield sb + finally: + await sb.async_close() + + +@pytest.fixture(scope="session") +async def fs() -> AsyncIterator[RouterAsyncFS]: + fs = RouterAsyncFS() + try: + yield fs + finally: + await fs.close() + + +@pytest.fixture(scope="session") +def tmpdir() -> str: + return os.path.join( + get_remote_tmpdir('test_batch_service_backend.py::tmpdir'), + secret_alnum_string(5), # create a unique URL for each split of the tests + ) + + +@pytest.fixture +def output_tmpdir(tmpdir: str) -> str: + return os.path.join(tmpdir, 'output', secret_alnum_string(5)) + + +@pytest.fixture +def output_bucket_path(fs: RouterAsyncFS, 
output_tmpdir: str) -> Tuple[str, str, str]: + url = fs.parse_url(output_tmpdir) + bucket = '/'.join(url.bucket_parts) + path = url.path + path = '/' + os.path.join(bucket, path) + return bucket, path, output_tmpdir + + +@pytest.fixture(scope="session") +async def upload_test_files( + fs: RouterAsyncFS, tmpdir: str +) -> Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]]: + test_files = ( + (os.path.join(tmpdir, 'inputs/hello.txt'), b'hello world'), + (os.path.join(tmpdir, 'inputs/hello spaces.txt'), b'hello'), + (os.path.join(tmpdir, 'inputs/hello (foo) spaces.txt'), b'hello'), + ) + await asyncio.gather(*(fs.write(url, data) for url, data in test_files)) + return test_files + + +def batch(backend, **kwargs): + name_of_test_method = inspect.stack()[1][3] + return Batch( + name=name_of_test_method, + backend=backend, + default_image=DOCKER_ROOT_IMAGE, + attributes={'foo': 'a', 'bar': 'b'}, + **kwargs, + ) + + +def test_single_task_no_io(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.command('echo hello') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_input( + backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(backend) + input = b.read_input(url1) + j = b.new_job() + j.command(f'cat {input}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_input_resource_group( + backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(backend) + input = b.read_input_group(foo=url1) + j = b.new_job() + j.storage('10Gi') + j.command(f'cat {input.foo}') + j.command(f'cat {input}.foo') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_output(backend: ServiceBackend): + b = batch(backend) + j = b.new_job(attributes={'a': 'bar', 'b': 'foo'}) + j.command(f'echo hello > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_write_output(backend: ServiceBackend, output_tmpdir: str): + b = batch(backend) + j = b.new_job() + j.command(f'echo hello > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'test_single_task_output.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_resource_group(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.declare_resource_group(output={'foo': '{root}.foo'}) + assert isinstance(j.output, ResourceGroup) + j.command(f'echo "hello" > {j.output.foo}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_write_resource_group(backend: ServiceBackend, output_tmpdir: str): + b = batch(backend) + j = b.new_job() + j.declare_resource_group(output={'foo': '{root}.foo'}) + assert isinstance(j.output, ResourceGroup) + j.command(f'echo "hello" > {j.output.foo}') + b.write_output(j.output, os.path.join(output_tmpdir, 
'test_single_task_write_resource_group')) + b.write_output(j.output.foo, os.path.join(output_tmpdir, 'test_single_task_write_resource_group_file.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_multiple_dependent_tasks(backend: ServiceBackend, output_tmpdir: str): + output_file = os.path.join(output_tmpdir, 'test_multiple_dependent_tasks.txt') + b = batch(backend) + j = b.new_job() + j.command(f'echo "0" >> {j.ofile}') + + for i in range(1, 3): + j2 = b.new_job() + j2.command(f'echo "{i}" > {j2.tmp1}') + j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') + j = j2 + + b.write_output(j.ofile, output_file) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_specify_cpu(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.cpu('0.5') + j.command(f'echo "hello" > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_specify_memory(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.memory('100M') + j.command(f'echo "hello" > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_scatter_gather(backend: ServiceBackend): + b = batch(backend) + + for i in range(3): + j = b.new_job(name=f'foo{i}') + j.command(f'echo "{i}" > {j.ofile}') + + merger = b.new_job() + merger.command( + 'cat {files} > {ofile}'.format( + files=' '.join( + [j.ofile for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True)] # type: ignore + ), + ofile=merger.ofile, + ) + ) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_file_name_space( + backend: ServiceBackend, + upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]], + output_tmpdir: str, +): + _, _, (url3, data3) = upload_test_files + b = batch(backend) + input = b.read_input(url3) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'hello (foo) spaces.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_dry_run(backend: ServiceBackend, output_tmpdir: str): + b = batch(backend) + j = b.new_job() + j.command(f'echo hello > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'test_single_job_output.txt')) + b.run(dry_run=True) + + +def test_verbose( + backend: ServiceBackend, + upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]], + output_tmpdir: str, +): + (url1, data1), _, _ = upload_test_files + b = batch(backend) + input = b.read_input(url1) + j = b.new_job() + j.command(f'cat {input}') + b.write_output(input, os.path.join(output_tmpdir, 'hello.txt')) + res = b.run(verbose=True) + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_cloudfuse_fails_with_read_write_mount_option(fs: RouterAsyncFS, backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') 
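test_scatter_gather above leans on Batch.select_jobs to collect the scatter jobs by name prefix. A short sketch of that pattern on its own, assuming a ServiceBackend configured via hailctl; job names and commands are placeholders.

from hailtop.batch import Batch, ServiceBackend

backend = ServiceBackend()
b = Batch(backend=backend)

# Scatter: one job per shard, named with a common prefix.
for i in range(3):
    j = b.new_job(name=f'shard{i}')
    j.command(f'echo "{i}" > {j.ofile}')

# Gather: select the scatter jobs by prefix and concatenate their outputs.
merge = b.new_job(name='merge')
shard_files = ' '.join(str(j.ofile) for j in b.select_jobs('shard'))
merge.command(f'cat {shard_files} > {merge.ofile}')

b.run()
backend.close()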
+ j.cloudfuse(bucket, f'/{bucket}', read_only=False) + + try: + b.run() + except ClientResponseError as e: + assert 'Only read-only cloudfuse requests are supported' in e.body, e.body + else: + assert False + + +def test_cloudfuse_fails_with_io_mount_point(fs: RouterAsyncFS, backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') + j.cloudfuse(bucket, f'/io', read_only=True) + + try: + b.run() + except ClientResponseError as e: + assert 'Cloudfuse requests with mount_path=/io are not supported' in e.body, e.body + else: + assert False + + +def test_cloudfuse_read_only(backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') + j.cloudfuse(bucket, f'/{bucket}', read_only=True) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_cloudfuse_implicit_dirs(fs: RouterAsyncFS, backend: ServiceBackend, upload_test_files): + (url1, data1), _, _ = upload_test_files + parsed_url1 = fs.parse_url(url1) + object_name = parsed_url1.path + bucket_name = '/'.join(parsed_url1.bucket_parts) + + b = batch(backend) + j = b.new_job() + j.command(f'cat ' + os.path.join('/cloudfuse', object_name)) + j.cloudfuse(bucket_name, f'/cloudfuse', read_only=True) + + res = b.run() + assert res + res_status = res.status() + assert res.get_job_log(1)['main'] == data1.decode() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_cloudfuse_empty_string_bucket_fails(backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(backend) + j = b.new_job() + with pytest.raises(BatchException): + j.cloudfuse('', '/empty_bucket') + with pytest.raises(BatchException): + j.cloudfuse(bucket, '') + + +async def test_cloudfuse_submount_in_io_doesnt_rm_bucket( + fs: RouterAsyncFS, backend: ServiceBackend, output_bucket_path +): + bucket, path, output_tmpdir = output_bucket_path + + should_still_exist_url = os.path.join(output_tmpdir, 'should-still-exist') + await fs.write(should_still_exist_url, b'should-still-exist') + + b = batch(backend) + j = b.new_job() + j.cloudfuse(bucket, '/io/cloudfuse') + j.command(f'ls /io/cloudfuse/') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert await fs.read(should_still_exist_url) == b'should-still-exist' + + +@skip_in_azure +def test_fuse_requester_pays(backend: ServiceBackend): + assert REQUESTER_PAYS_PROJECT + b = batch(backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + j = b.new_job() + j.cloudfuse('hail-test-requester-pays-fds32', '/fuse-bucket') + j.command('cat /fuse-bucket/hello') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +@skip_in_azure +def test_fuse_non_requester_pays_bucket_when_requester_pays_project_specified( + backend: ServiceBackend, output_bucket_path +): + bucket, path, output_tmpdir = output_bucket_path + assert REQUESTER_PAYS_PROJECT + + b = batch(backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + j = b.new_job() + j.command(f'ls /fuse-bucket') + j.cloudfuse(bucket, f'/fuse-bucket', read_only=True) + 
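The cloudfuse tests above establish the supported shape of a mount: read-only, and anywhere except /io. A sketch of a valid read-only mount, assuming a ServiceBackend configured via hailctl; the bucket and object names are placeholders, not resources from this change.

import os

from hailtop.batch import Batch, ServiceBackend

backend = ServiceBackend()
b = Batch(backend=backend, default_image='ubuntu:22.04')

j = b.new_job()
# Mount the bucket read-only; read-write mounts and /io mount points are rejected.
j.cloudfuse('my-bucket', '/cloudfuse', read_only=True)
j.command('cat ' + os.path.join('/cloudfuse', 'inputs/hello.txt'))

b.run()
backend.close()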
+ res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +@skip_in_azure +def test_requester_pays(backend: ServiceBackend): + assert REQUESTER_PAYS_PROJECT + b = batch(backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + input = b.read_input('gs://hail-test-requester-pays-fds32/hello') + j = b.new_job() + j.command(f'cat {input}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_benchmark_lookalike_workflow(backend: ServiceBackend, output_tmpdir): + b = batch(backend) + + setup_jobs = [] + for i in range(10): + j = b.new_job(f'setup_{i}').cpu(0.25) + j.command(f'echo "foo" > {j.ofile}') + setup_jobs.append(j) + + jobs = [] + for i in range(500): + j = b.new_job(f'create_file_{i}').cpu(0.25) + j.command(f'echo {setup_jobs[i % len(setup_jobs)].ofile} > {j.ofile}') + j.command(f'echo "bar" >> {j.ofile}') + jobs.append(j) + + combine = b.new_job(f'combine_output').cpu(0.25) + for _ in grouped(arg_max(), jobs): + combine.command(f'cat {" ".join(shq(j.ofile) for j in jobs)} >> {combine.ofile}') + b.write_output(combine.ofile, os.path.join(output_tmpdir, 'pipeline_benchmark_test.txt')) + # too slow + # assert b.run().status()['state'] == 'success' + + +def test_envvar(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.env('SOME_VARIABLE', '123abcdef') + j.command('[ $SOME_VARIABLE = "123abcdef" ]') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_job_with_shell(backend: ServiceBackend): + msg = 'hello world' + b = batch(backend) + j = b.new_job(shell='/bin/sh') + j.command(f'echo "{msg}"') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_job_with_nonsense_shell(backend: ServiceBackend): + b = batch(backend) + j = b.new_job(shell='/bin/ajdsfoijasidojf') + j.command(f'echo "hello"') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_single_job_with_intermediate_failure(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.command(f'echoddd "hello"') + j2 = b.new_job() + j2.command(f'echo "world"') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_input_directory( + backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(backend) + containing_folder = '/'.join(url1.rstrip('/').split('/')[:-1]) + input1 = b.read_input(containing_folder) + input2 = b.read_input(containing_folder + '/') + j = b.new_job() + j.command(f'ls {input1}/hello.txt') + j.command(f'ls {input2}/hello.txt') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_python_job(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.command(f'echo "5" > {head.r5}') + head.command(f'echo "3" > {head.r3}') + + def read(path): + with open(path, 'r') as f: + i = f.read() + return int(i) + + def multiply(x, y): + return x * y + + def reformat(x, y): + 
return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.r3) + r5 = middle.call(read, head.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) + + +def test_python_job_w_resource_group_unpack_individually(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) + assert isinstance(head.count, ResourceGroup) + + head.command(f'echo "5" > {head.count.r5}') + head.command(f'echo "3" > {head.count.r3}') + + def read(path): + with open(path, 'r') as f: + r = int(f.read()) + return r + + def multiply(x, y): + return x * y + + def reformat(x, y): + return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.count.r3) + r5 = middle.call(read, head.count.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) + + +def test_python_job_can_write_to_resource_path(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(path): + with open(path, 'w') as f: + f.write('foo') + + head = b.new_python_job() + head.call(write, head.ofile) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == 'foo', str(res.debug_info()) + + +def test_python_job_w_resource_group_unpack_jointly(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) + assert isinstance(head.count, ResourceGroup) + + head.command(f'echo "5" > {head.count.r5}') + head.command(f'echo "3" > {head.count.r3}') + + def read_rg(root): + with open(root['r3'], 'r') as f: + r3 = int(f.read()) + with open(root['r5'], 'r') as f: + r5 = int(f.read()) + return (r3, r5) + + def multiply(r): + x, y = r + return x * y + + middle = b.new_python_job() + r = middle.call(read_rg, head.count) + r_mult = middle.call(multiply, r) + + tail = b.new_job() + tail.command(f'cat {r_mult.as_str()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + job_log_3 = res.get_job_log(3) + assert job_log_3['main'] == "15\n", str((job_log_3, res.debug_info())) + + +def test_python_job_w_non_zero_ec(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + j = b.new_python_job() + + def error(): + raise 
Exception("this should fail") + + j.call(error) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_python_job_incorrect_signature(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + + def foo(pos_arg1, pos_arg2, *, kwarg1, kwarg2=1): + print(pos_arg1, pos_arg2, kwarg1, kwarg2) + + j = b.new_python_job() + + with pytest.raises(BatchException): + j.call(foo) + with pytest.raises(BatchException): + j.call(foo, 1) + with pytest.raises(BatchException): + j.call(foo, 1, 2) + with pytest.raises(BatchException): + j.call(foo, 1, kwarg1=2) + with pytest.raises(BatchException): + j.call(foo, 1, 2, 3) + with pytest.raises(BatchException): + j.call(foo, 1, 2, kwarg1=3, kwarg2=4, kwarg3=5) + + j.call(foo, 1, 2, kwarg1=3) + j.call(foo, 1, 2, kwarg1=3, kwarg2=4) + + # `print` doesn't have a signature but other builtins like `abs` do + j.call(print, 5) + j.call(abs, -1) + with pytest.raises(BatchException): + j.call(abs, -1, 5) + + +def test_fail_fast(backend: ServiceBackend): + b = batch(backend, cancel_after_n_failures=1) + + j1 = b.new_job() + j1.command('false') + + j2 = b.new_job() + j2.command('sleep 300') + + res = b.run() + assert res + job_status = res.get_job(2).status() + assert job_status['state'] == 'Cancelled', str((job_status, res.debug_info())) + + +def test_service_backend_remote_tempdir_with_trailing_slash(backend): + b = Batch(backend=backend) + j1 = b.new_job() + j1.command(f'echo hello > {j1.ofile}') + j2 = b.new_job() + j2.command(f'cat {j1.ofile}') + b.run() + + +def test_service_backend_remote_tempdir_with_no_trailing_slash(backend): + b = Batch(backend=backend) + j1 = b.new_job() + j1.command(f'echo hello > {j1.ofile}') + j2 = b.new_job() + j2.command(f'cat {j1.ofile}') + b.run() + + +def test_large_command(backend: ServiceBackend): + b = Batch(backend=backend) + j1 = b.new_job() + long_str = secrets.token_urlsafe(15 * 1024) + j1.command(f'echo "{long_str}"') + b.run() + + +def test_big_batch_which_uses_slow_path(backend: ServiceBackend): + b = Batch(backend=backend) + # 8 * 256 * 1024 = 2 MiB > 1 MiB max bunch size + for _ in range(8): + j1 = b.new_job() + long_str = secrets.token_urlsafe(256 * 1024) + j1.command(f'echo "{long_str}" > /dev/null') + res = b.run() + assert res + assert not res._submission_info.used_fast_path + batch_status = res.status() + assert batch_status['state'] == 'success', str((res.debug_info())) + + +def test_query_on_batch_in_batch(backend: ServiceBackend, output_tmpdir: str): + bb = Batch(backend=backend, default_python_image=HAIL_GENETICS_HAIL_IMAGE) + + tmp_ht_path = os.path.join(output_tmpdir, secrets.token_urlsafe(32)) + + def qob_in_batch(): + import hail as hl + + hl.utils.range_table(10).write(tmp_ht_path, overwrite=True) + + j = bb.new_python_job() + j.env('HAIL_QUERY_BACKEND', 'batch') + j.env('HAIL_BATCH_BILLING_PROJECT', configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, None, '')) + j.env('HAIL_BATCH_REMOTE_TMPDIR', output_tmpdir) + j.call(qob_in_batch) + + bb.run() + + +def test_basic_async_fun(backend: ServiceBackend): + b = Batch(backend=backend) + + j = b.new_python_job() + j.call(asyncio.sleep, 1) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((res.debug_info())) + + +def test_async_fun_returns_value(backend: ServiceBackend): + b = Batch(backend=backend) + + async def foo(i, j): + await asyncio.sleep(1) + return i * j + + j = 
b.new_python_job() + result = j.call(foo, 2, 3) + + j = b.new_job() + j.command(f'cat {result.as_str()}') + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + job_log_2 = res.get_job_log(2) + assert job_log_2['main'] == "6\n", str((job_log_2, res.debug_info())) + + +def test_specify_job_region(backend: ServiceBackend): + b = batch(backend, cancel_after_n_failures=1) + j = b.new_job('region') + possible_regions = backend.supported_regions() + j.regions(possible_regions) + j.command('true') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_always_copy_output(backend: ServiceBackend, output_tmpdir: str): + output_path = os.path.join(output_tmpdir, 'test_always_copy_output.txt') + + b = batch(backend) + j = b.new_job() + j.always_copy_output() + j.command(f'echo "hello" > {j.ofile} && false') + + b.write_output(j.ofile, output_path) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + b2 = batch(backend) + input = b2.read_input(output_path) + file_exists_j = b2.new_job() + file_exists_j.command(f'cat {input}') + + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(1)['main'] == "hello\n", str(res.debug_info()) + + +def test_no_copy_output_on_failure(backend: ServiceBackend, output_tmpdir: str): + output_path = os.path.join(output_tmpdir, 'test_no_copy_output.txt') + + b = batch(backend) + j = b.new_job() + j.command(f'echo "hello" > {j.ofile} && false') + + b.write_output(j.ofile, output_path) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + b2 = batch(backend) + input = b2.read_input(output_path) + file_exists_j = b2.new_job() + file_exists_j.command(f'cat {input}') + + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_update_batch(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.command('true') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + j2 = b.new_job() + j2.command('true') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_update_batch_with_dependencies(backend: ServiceBackend): + b = batch(backend) + j1 = b.new_job() + j1.command('true') + j2 = b.new_job() + j2.command('false') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + j3 = b.new_job() + j3.command('true') + j3.depends_on(j1) + + j4 = b.new_job() + j4.command('true') + j4.depends_on(j2) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + assert res.get_job(3).status()['state'] == 'Success', str((res_status, res.debug_info())) + assert res.get_job(4).status()['state'] == 'Cancelled', str((res_status, res.debug_info())) + + +def test_update_batch_with_python_job_dependencies(backend: ServiceBackend): + b = batch(backend) + + async def foo(i, j): + await asyncio.sleep(1) + 
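The test_update_batch* functions above call b.run() repeatedly on one Batch object; each subsequent run updates the same batch with the jobs added since the previous run, and new jobs may depend on jobs from an earlier submission. A sketch of that flow, assuming a ServiceBackend configured via hailctl.

from hailtop.batch import Batch, ServiceBackend

backend = ServiceBackend()
b = Batch(backend=backend)

j1 = b.new_job()
j1.command('true')
first = b.run()          # initial submission with job 1

j2 = b.new_job()
j2.depends_on(j1)        # may depend on a job submitted earlier
j2.command('true')
second = b.run()         # update: only job 2 is submitted this time

assert first is not None and second is not None
backend.close()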
return i * j + + j1 = b.new_python_job() + j1.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + j2 = b.new_python_job() + j2.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + j3 = b.new_python_job() + j3.depends_on(j2) + j3.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + +def test_update_batch_from_batch_id(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.command('true') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + b2 = Batch.from_batch_id(res.id, backend=b._backend) + j2 = b2.new_job() + j2.command('true') + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +async def test_python_job_with_kwarg(fs: RouterAsyncFS, backend: ServiceBackend, output_tmpdir: str): + def foo(*, kwarg): + return kwarg + + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + j = b.new_python_job() + r = j.call(foo, kwarg='hello world') + + output_path = os.path.join(output_tmpdir, 'test_python_job_with_kwarg') + b.write_output(r.as_json(), output_path) + res = b.run() + assert isinstance(res, bc.Batch) + + assert res.status()['state'] == 'success', str((res, res.debug_info())) + assert orjson.loads(await fs.read(output_path)) == 'hello world' + + +def test_tuple_recursive_resource_extraction_in_python_jobs(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(paths): + if not isinstance(paths, tuple): + raise ValueError('paths must be a tuple') + for i, path in enumerate(paths): + with open(path, 'w') as f: + f.write(f'{i}') + + head = b.new_python_job() + head.call(write, (head.ofile1, head.ofile2)) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) + + +def test_list_recursive_resource_extraction_in_python_jobs(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(paths): + for i, path in enumerate(paths): + with open(path, 'w') as f: + f.write(f'{i}') + + head = b.new_python_job() + head.call(write, [head.ofile1, head.ofile2]) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) + + +def test_dict_recursive_resource_extraction_in_python_jobs(backend: ServiceBackend): + b = batch(backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(kwargs): + for k, v in kwargs.items(): + with open(v, 'w') as f: + f.write(k) + + head = b.new_python_job() + head.call(write, {'a': head.ofile1, 'b': head.ofile2}) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + 
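The python-job tests above pass PythonResults between jobs and render them into bash commands. A condensed sketch of that flow, assuming a ServiceBackend configured via hailctl and the same python-dill image used throughout this file.

from hailtop.batch import Batch, ServiceBackend

backend = ServiceBackend()
b = Batch(backend=backend, default_python_image='hailgenetics/python-dill:3.9-slim')


def add(x, y):
    return x + y


def describe(total):
    return {'total': total}


p1 = b.new_python_job()
total = p1.call(add, 2, 3)            # a PythonResult, not a plain int

p2 = b.new_python_job()
summary = p2.call(describe, total)    # results can feed later python jobs

tail = b.new_job()
# Rendered to files readable by a bash job via as_str()/as_json().
tail.command(f'cat {total.as_str()} {summary.as_json()}')

b.run()
backend.close()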
res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == 'ab', str(res.debug_info()) + + +def test_wait_on_empty_batch_update(backend: ServiceBackend): + b = batch(backend) + b.run(wait=True) + b.run(wait=True) + + +def test_non_spot_job(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.spot(False) + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] == False + + +def test_spot_unspecified_job(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] == True + + +def test_spot_true_job(backend: ServiceBackend): + b = batch(backend) + j = b.new_job() + j.spot(True) + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] == True + + +def test_non_spot_batch(backend: ServiceBackend): + b = batch(backend, default_spot=False) + j1 = b.new_job() + j1.command('echo hello') + j2 = b.new_job() + j2.command('echo hello') + j3 = b.new_job() + j3.spot(True) + j3.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] == False + assert res.get_job(2).status()['spec']['resources']['preemptible'] == False + assert res.get_job(3).status()['spec']['resources']['preemptible'] == True + + +def test_local_file_paths_error(backend: ServiceBackend): + b = batch(backend) + b.new_job() + for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: + with pytest.raises(ValueError) as e: + b.read_input(input) + assert str(e.value).startswith("Local filepath detected") + + +@skip_in_azure +def test_validate_cloud_storage_policy(backend, monkeypatch): + # buckets do not exist (bucket names can't contain the string "google" per + # https://cloud.google.com/storage/docs/buckets) + fake_bucket1 = "google" + fake_bucket2 = "google1" + no_bucket_error = "bucket does not exist" + # bucket exists, but account does not have permissions on it + no_perms_bucket = "test" + no_perms_error = "does not have storage.buckets.get access" + # bucket exists and account has permissions, but is set to use cold storage by default + cold_bucket = "hail-test-cold-storage" + cold_error = "configured to use cold storage by default" + fake_uri1, fake_uri2, no_perms_uri, cold_uri = [ + f"gs://{bucket}/test" for bucket in [fake_bucket1, fake_bucket2, no_perms_bucket, cold_bucket] + ] + + def _test_raises(exception_type, exception_msg, func): + with pytest.raises(exception_type) as e: + func() + assert exception_msg in str(e.value) + + def _test_raises_no_bucket_error(remote_tmpdir, arg=None): + _test_raises( + ClientResponseError, + no_bucket_error, + lambda: ServiceBackend(remote_tmpdir=remote_tmpdir, gcs_bucket_allow_list=arg), + ) + + def _test_raises_cold_error(func): + _test_raises(ValueError, cold_error, func) + + # no configuration, nonexistent buckets error + _test_raises_no_bucket_error(fake_uri1) + _test_raises_no_bucket_error(fake_uri2) + + # no configuration, no perms bucket errors + _test_raises(ClientResponseError, no_perms_error, lambda: ServiceBackend(remote_tmpdir=no_perms_uri)) + + # no configuration, cold bucket errors + _test_raises_cold_error(lambda: ServiceBackend(remote_tmpdir=cold_uri)) + b = batch(backend) 
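The spot tests above pin down the precedence rule: a per-job j.spot(...) call overrides the batch-level default_spot, and jobs default to preemptible when neither is set. A sketch, assuming a ServiceBackend configured via hailctl.

from hailtop.batch import Batch, ServiceBackend

backend = ServiceBackend()
b = Batch(backend=backend, default_spot=False)   # batch default: non-preemptible

j1 = b.new_job()
j1.command('echo hello')    # inherits default_spot=False

j2 = b.new_job()
j2.spot(True)               # per-job override: runs preemptible
j2.command('echo hello')

b.run()
backend.close()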
+    _test_raises_cold_error(lambda: b.read_input(cold_uri))
+    j = b.new_job()
+    j.command(f"echo hello > {j.ofile}")
+    _test_raises_cold_error(lambda: b.write_output(j.ofile, cold_uri))
+
+    # hailctl config, allowlisted nonexistent buckets don't error
+    base_config = get_user_config()
+    local_config = ConfigParser()
+    local_config.read_dict(
+        {
+            **{section: {key: val for key, val in base_config[section].items()} for section in base_config.sections()},
+            **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}},
+        }
+    )
+
+    def _get_user_config():
+        return local_config
+
+    monkeypatch.setattr(user_config, "get_user_config", _get_user_config)
+    ServiceBackend(remote_tmpdir=fake_uri1)
+    ServiceBackend(remote_tmpdir=fake_uri2)
+
+    # environment variable config, only allowlisted nonexistent buckets don't error
+    monkeypatch.setenv("HAIL_GCS_BUCKET_ALLOW_LIST", fake_bucket2)
+    _test_raises_no_bucket_error(fake_uri1)
+    ServiceBackend(remote_tmpdir=fake_uri2)
+
+    # arg to constructor config, only allowlisted nonexistent buckets don't error
+    arg = [fake_bucket1]
+    ServiceBackend(remote_tmpdir=fake_uri1, gcs_bucket_allow_list=arg)
+    _test_raises_no_bucket_error(fake_uri2, arg)
diff --git a/hail/python/test/hailtop/conftest.py b/hail/python/test/hailtop/conftest.py
index 9b612b07d6d..d8c573e72b4 100644
--- a/hail/python/test/hailtop/conftest.py
+++ b/hail/python/test/hailtop/conftest.py
@@ -1,10 +1,19 @@
+import asyncio
 import hashlib
 import os
-
 import pytest
 
 
-def pytest_collection_modifyitems(config, items):
+@pytest.fixture(scope="session")
+def event_loop():
+    loop = asyncio.get_event_loop()
+    try:
+        yield loop
+    finally:
+        loop.close()
+
+
+def pytest_collection_modifyitems(items):
     n_splits = int(os.environ.get('HAIL_RUN_IMAGE_SPLITS', '1'))
     split_index = int(os.environ.get('HAIL_RUN_IMAGE_SPLIT_INDEX', '-1'))
     if n_splits <= 1:
diff --git a/hail/python/test/hailtop/inter_cloud/test_copy.py b/hail/python/test/hailtop/inter_cloud/test_copy.py
index 3bd3ce157a9..98655f9822f 100644
--- a/hail/python/test/hailtop/inter_cloud/test_copy.py
+++ b/hail/python/test/hailtop/inter_cloud/test_copy.py
@@ -18,14 +18,6 @@
 from .copy_test_specs import COPY_TEST_SPECS
 
 
-@pytest.fixture(scope='module')
-def event_loop():
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    yield loop
-    loop.close()
-
-
 # This fixture is for test_copy_behavior. It runs a series of copy
 # test "specifications" by calling run_test_spec. The set of
 # specifications is enumerated by
@@ -125,7 +117,6 @@ async def copy_test_context(request, router_filesystem: Tuple[asyncio.Semaphore,
     yield sema, fs, src_base, dest_base
 
 
-@pytest.mark.asyncio
 async def test_copy_behavior(copy_test_context, test_spec):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -182,7 +173,6 @@ def __exit__(self, type, value, traceback):
         return True
 
 
-@pytest.mark.asyncio
 async def test_copy_doesnt_exist(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -190,7 +180,6 @@ async def test_copy_doesnt_exist(copy_test_context):
         await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base))
 
 
-@pytest.mark.asyncio
 async def test_copy_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -201,7 +190,6 @@ async def test_copy_file(copy_test_context):
     await expect_file(fs, f'{dest_base}a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_large_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -217,7 +205,6 @@ async def test_copy_large_file(copy_test_context):
     assert copy_contents == contents
 
 
-@pytest.mark.asyncio
 async def test_copy_rename_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -228,7 +215,6 @@ async def test_copy_rename_file(copy_test_context):
     await expect_file(fs, f'{dest_base}x', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_rename_file_dest_target_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -239,7 +225,6 @@ async def test_copy_rename_file_dest_target_file(copy_test_context):
     await expect_file(fs, f'{dest_base}x', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_file_dest_target_directory_doesnt_exist(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -250,7 +235,6 @@ async def test_copy_file_dest_target_directory_doesnt_exist(copy_test_context):
     await expect_file(fs, f'{dest_base}x/a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_overwrite_rename_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -262,7 +246,6 @@ async def test_overwrite_rename_file(copy_test_context):
     await expect_file(fs, f'{dest_base}x', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_rename_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -274,7 +257,6 @@ async def test_copy_rename_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}x/subdir/file2', 'src/a/subdir/file2')
 
 
-@pytest.mark.asyncio
 async def test_copy_rename_dir_dest_is_target(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -286,7 +268,6 @@ async def test_copy_rename_dir_dest_is_target(copy_test_context):
     await expect_file(fs, f'{dest_base}x/subdir/file2', 'src/a/subdir/file2')
 
 
-@pytest.mark.asyncio
 async def test_overwrite_rename_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -300,7 +281,6 @@ async def test_overwrite_rename_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}x/file3', 'dest/x/file3')
 
 
-@pytest.mark.asyncio
 async def test_copy_file_dest_trailing_slash_target_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -311,7 +291,6 @@ async def test_copy_file_dest_trailing_slash_target_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_file_dest_target_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -322,7 +301,6 @@ async def test_copy_file_dest_target_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_file_dest_target_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -333,7 +311,6 @@ async def test_copy_file_dest_target_file(copy_test_context):
     await expect_file(fs, f'{dest_base}a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_dest_target_file_is_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -345,7 +322,6 @@ async def test_copy_dest_target_file_is_dir(copy_test_context):
         )
 
 
-@pytest.mark.asyncio
 async def test_overwrite_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -357,7 +333,6 @@ async def test_overwrite_file(copy_test_context):
     await expect_file(fs, f'{dest_base}a', 'src/a')
 
 
-@pytest.mark.asyncio
 async def test_copy_file_src_trailing_slash(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -367,7 +342,6 @@ async def test_copy_file_src_trailing_slash(copy_test_context):
         await Copier.copy(fs, sema, Transfer(f'{src_base}a/', dest_base))
 
 
-@pytest.mark.asyncio
 async def test_copy_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -379,7 +353,6 @@ async def test_copy_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}a/subdir/file2', 'src/a/subdir/file2')
 
 
-@pytest.mark.asyncio
 async def test_overwrite_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -393,7 +366,6 @@ async def test_overwrite_dir(copy_test_context):
     await expect_file(fs, f'{dest_base}a/file3', 'dest/a/file3')
 
 
-@pytest.mark.asyncio
 async def test_copy_multiple(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -406,7 +378,6 @@ async def test_copy_multiple(copy_test_context):
     await expect_file(fs, f'{dest_base}b', 'src/b')
 
 
-@pytest.mark.asyncio
 async def test_copy_multiple_dest_target_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -421,7 +392,6 @@ async def test_copy_multiple_dest_target_file(copy_test_context):
         )
 
 
-@pytest.mark.asyncio
 async def test_copy_multiple_dest_file(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -433,7 +403,6 @@ async def test_copy_multiple_dest_file(copy_test_context):
         await Copier.copy(fs, sema, Transfer([f'{src_base}a', f'{src_base}b'], f'{dest_base}x'))
 
 
-@pytest.mark.asyncio
 async def test_file_overwrite_dir(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -445,7 +414,6 @@ async def test_file_overwrite_dir(copy_test_context):
         )
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str
 ):
@@ -461,7 +429,6 @@ async def test_file_and_directory_error(
         await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/')))
 
 
-@pytest.mark.asyncio
 async def test_copy_src_parts(copy_test_context):
     sema, fs, src_base, dest_base = copy_test_context
 
@@ -486,7 +453,6 @@ async def collect_files(it: AsyncIterator[FileListEntry]) -> List[str]:
     return [await x.url() async for x in it]
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_empty_file(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str
 ):
@@ -526,7 +492,6 @@ async def test_file_and_directory_error_with_slash_empty_file(
         await expect_file(fs, exp_dest, 'foo')
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non_recursive(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]
 ):
@@ -544,7 +509,6 @@ async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non
         await collect_files(await fs.listfiles(f'{src_base}not-empty/'))
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_non_empty_file(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str
 ):
@@ -588,7 +552,6 @@ async def test_file_and_directory_error_with_slash_non_empty_file(
             await Copier.copy(fs, sema, Transfer(f'{src_base}', dest_base.rstrip('/'), treat_dest_as=transfer_type))
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_non_empty_file_only_for_google_non_recursive(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]
 ):
@@ -612,7 +575,6 @@ async def test_file_and_directory_error_with_slash_non_empty_file_only_for_googl
        await collect_files(await fs.listfiles(f'{dest_base}empty-only/'))
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_empty_file_only(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str
 ):
@@ -638,7 +600,6 @@ async def test_file_and_directory_error_with_slash_empty_file_only(
             await Copier.copy(fs, sema, Transfer(f'{src_base}', dest_base.rstrip('/'), treat_dest_as=transfer_type))
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_non_empty_file_only_google_non_recursive(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]
 ):
@@ -655,7 +616,6 @@ async def test_file_and_directory_error_with_slash_non_empty_file_only_google_no
         await collect_files(await fs.listfiles(f'{src_base}not-empty-file-w-slash/'))
 
 
-@pytest.mark.asyncio
 async def test_file_and_directory_error_with_slash_non_empty_file_only(
     router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str
 ):
diff --git a/hail/python/test/hailtop/inter_cloud/test_diff.py b/hail/python/test/hailtop/inter_cloud/test_diff.py
index 12fb62b70d6..a89035dbb9e 100644
--- a/hail/python/test/hailtop/inter_cloud/test_diff.py
+++ b/hail/python/test/hailtop/inter_cloud/test_diff.py
@@ -12,14 +12,6 @@
 from hailtop.aiotools.router_fs import RouterAsyncFS
 
 
-@pytest.fixture(scope='module')
-def event_loop():
-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    yield loop
-    loop.close()
-
-
 @pytest.fixture(scope='module')
 async def router_filesystem() -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]]:
     token = secrets.token_hex(16)
diff --git a/hail/python/test/hailtop/inter_cloud/test_fs.py b/hail/python/test/hailtop/inter_cloud/test_fs.py
index d15dcdfaf82..e83ba02fe2a 100644
--- a/hail/python/test/hailtop/inter_cloud/test_fs.py
+++ b/hail/python/test/hailtop/inter_cloud/test_fs.py
@@ -515,7 +515,7 @@ async def test_statfile_creation_and_modified_time(filesystem: Tuple[asyncio.Sem
     status = await fs.statfile(file)
 
     if isinstance(fs, RouterAsyncFS):
-        is_local = isinstance(fs._get_fs(file), LocalAsyncFS)
+        is_local = isinstance(await fs._get_fs(file), LocalAsyncFS)
     else:
         is_local = isinstance(fs, LocalAsyncFS)
 
diff --git a/hail/python/test/hailtop/test_aiogoogle.py b/hail/python/test/hailtop/test_aiogoogle.py
index f094d2aac1f..462630f8d79 100644
--- a/hail/python/test/hailtop/test_aiogoogle.py
+++ b/hail/python/test/hailtop/test_aiogoogle.py
@@ -49,7 +49,6 @@ def test_bucket_path_parsing():
     assert bucket == 'foo' and prefix == 'bar/baz'
 
 
-@pytest.mark.asyncio
 async def test_get_object_metadata(bucket_and_temporary_file):
     bucket, file = bucket_and_temporary_file
 
@@ -67,7 +66,6 @@ async def upload():
     assert int(metadata['size']) == 3
 
 
-@pytest.mark.asyncio
 async def test_get_object_headers(bucket_and_temporary_file):
     bucket, file = bucket_and_temporary_file
 
@@ -85,7 +83,6 @@ async def upload():
         assert await f.read() == b'foo'
 
 
-@pytest.mark.asyncio
 async def test_compose(bucket_and_temporary_file):
     bucket, file = bucket_and_temporary_file
 
@@ -107,7 +104,6 @@ async def upload(i, b):
     assert actual == expected
 
 
-@pytest.mark.asyncio
 async def test_multi_part_create_many_two_level_merge(gs_filesystem):
     # This is a white-box test. compose has a maximum of 32 inputs,
     # so if we're composing more than 32 parts, the
@@ -144,7 +140,6 @@ async def create_part(i):
         raise AssertionError('uncaught cancelled error') from err
 
 
-@pytest.mark.asyncio
 async def test_weird_urls(gs_filesystem):
     _, fs, base = gs_filesystem