Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[batch] re-allow sending kwargs to python jobs #13505

Merged
merged 7 commits into from Sep 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
50 changes: 24 additions & 26 deletions devbin/generate_gcp_ar_cleanup_policy.py
Expand Up @@ -12,15 +12,17 @@ def to_dict(self):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't entirely understand why pyright runs on everything but it does and it dislikes the formatting of this file.


class DeletePolicy(CleanupPolicy):
def __init__(self,
name: str,
tag_state: str,
*,
tag_prefixes: Optional[List[str]] = None,
version_name_prefixes: Optional[List[str]] = None,
package_name_prefixes: Optional[List[str]] = None,
older_than: Optional[str] = None,
newer_than: Optional[str] = None):
def __init__(
self,
name: str,
tag_state: str,
*,
tag_prefixes: Optional[List[str]] = None,
version_name_prefixes: Optional[List[str]] = None,
package_name_prefixes: Optional[List[str]] = None,
older_than: Optional[str] = None,
newer_than: Optional[str] = None
):
self.name = name
self.tag_state = tag_state
self.tag_prefixes = tag_prefixes
Expand All @@ -46,15 +48,17 @@ def to_dict(self):


class ConditionalKeepPolicy(CleanupPolicy):
def __init__(self,
name: str,
tag_state: str,
*,
tag_prefixes: Optional[List[str]] = None,
version_name_prefixes: Optional[List[str]] = None,
package_name_prefixes: Optional[List[str]] = None,
older_than: Optional[str] = None,
newer_than: Optional[str] = None):
def __init__(
self,
name: str,
tag_state: str,
*,
tag_prefixes: Optional[List[str]] = None,
version_name_prefixes: Optional[List[str]] = None,
package_name_prefixes: Optional[List[str]] = None,
older_than: Optional[str] = None,
newer_than: Optional[str] = None
):
self.name = name
self.tag_state = tag_state
self.tag_prefixes = tag_prefixes
Expand All @@ -80,10 +84,7 @@ def to_dict(self):


class MostRecentVersionKeepPolicy(CleanupPolicy):
def __init__(self,
name: str,
package_name_prefixes: List[str],
keep_count: int):
def __init__(self, name: str, package_name_prefixes: List[str], keep_count: int):
self.name = name
self.package_name_prefixes = package_name_prefixes
self.keep_count = keep_count
Expand All @@ -92,10 +93,7 @@ def to_dict(self):
data = {
'name': self.name,
'action': {'type': 'Keep'},
'mostRecentVersions': {
'packageNamePrefixes': self.package_name_prefixes,
'keepCount': self.keep_count
}
'mostRecentVersions': {'packageNamePrefixes': self.package_name_prefixes, 'keepCount': self.keep_count},
}
return data

Expand Down
51 changes: 35 additions & 16 deletions hail/python/hailtop/batch/job.py
Expand Up @@ -5,7 +5,7 @@
import textwrap
import warnings
from shlex import quote as shq
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast, Literal
from typing_extensions import Self

import hailtop.batch_client.client as bc
Expand Down Expand Up @@ -878,6 +878,19 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False):
return True


UnpreparedArg = Union['_resource.ResourceType', List['UnpreparedArg'], Tuple['UnpreparedArg', ...], Dict[str, 'UnpreparedArg'], Any]

PreparedArg = Union[
Tuple[Literal['py_path'], str],
Tuple[Literal['path'], str],
Tuple[Literal['dict_path'], Dict[str, str]],
Tuple[Literal['list'], List['PreparedArg']],
Tuple[Literal['dict'], Dict[str, 'PreparedArg']],
Tuple[Literal['tuple'], Tuple['PreparedArg', ...]],
Tuple[Literal['value'], Any]
]


class PythonJob(Job):
"""
Object representing a single Python job to execute.
Expand Down Expand Up @@ -924,7 +937,7 @@ def __init__(self,
super().__init__(batch, token, name=name, attributes=attributes, shell=None)
self._resources: Dict[str, _resource.Resource] = {}
self._resources_inverse: Dict[_resource.Resource, str] = {}
self._function_calls: List[Tuple[_resource.PythonResult, int, Tuple[Any, ...], Dict[str, Any]]] = []
self._function_calls: List[Tuple[_resource.PythonResult, int, Tuple[UnpreparedArg, ...], Dict[str, UnpreparedArg]]] = []
self.n_results = 0

def _get_python_resource(self, item: str) -> '_resource.PythonResult':
Expand Down Expand Up @@ -970,7 +983,7 @@ def image(self, image: str) -> 'PythonJob':
self._image = image
return self

def call(self, unapplied: Callable, *args, **kwargs) -> '_resource.PythonResult':
def call(self, unapplied: Callable, *args: UnpreparedArg, **kwargs: UnpreparedArg) -> '_resource.PythonResult':
"""Execute a Python function.

Examples
Expand Down Expand Up @@ -1148,28 +1161,32 @@ def handle_args(r):
return result

async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False):
def prepare_argument_for_serialization(arg):
def preserialize(arg: UnpreparedArg) -> PreparedArg:
if isinstance(arg, _resource.PythonResult):
return ('py_path', arg._get_path(local_tmpdir))
if isinstance(arg, _resource.ResourceFile):
return ('path', arg._get_path(local_tmpdir))
if isinstance(arg, _resource.ResourceGroup):
return ('dict_path', {name: resource._get_path(local_tmpdir)
for name, resource in arg._resources.items()})
if isinstance(arg, (list, tuple)):
return ('value', [prepare_argument_for_serialization(elt) for elt in arg])
if isinstance(arg, list):
return ('list', [preserialize(elt) for elt in arg])
if isinstance(arg, tuple):
return ('tuple', tuple((preserialize(elt) for elt in arg)))
if isinstance(arg, dict):
return ('value', {k: prepare_argument_for_serialization(v) for k, v in arg.items()})
return ('dict', {k: preserialize(v) for k, v in arg.items()})
return ('value', arg)

for i, (result, unapplied_id, args, kwargs) in enumerate(self._function_calls):
func_file = self._batch._python_function_files[unapplied_id]

prepared_args = prepare_argument_for_serialization(args)[1]
prepared_kwargs = prepare_argument_for_serialization(kwargs)[1]
preserialized_args = [preserialize(arg) for arg in args]
del args
preserialized_kwargs = {keyword: preserialize(arg) for keyword, arg in kwargs.items()}
del kwargs

args_file = await self._batch._serialize_python_to_input_file(
os.path.dirname(result._get_path(remote_tmpdir)), "args", i, (prepared_args, prepared_kwargs), dry_run
os.path.dirname(result._get_path(remote_tmpdir)), "args", i, (preserialized_args, preserialized_kwargs), dry_run
)

json_write, str_write, repr_write = [
Expand All @@ -1191,14 +1208,16 @@ def prepare_argument_for_serialization(arg):

def deserialize_argument(arg):
typ, val = arg
if typ == 'value' and isinstance(val, dict):
return {{k: deserialize_argument(v) for k, v in val.items()}}
if typ == 'value' and isinstance(val, (list, tuple)):
return [deserialize_argument(elt) for elt in val]
if typ == 'py_path':
return dill.load(open(val, 'rb'))
if typ in ('path', 'dict_path'):
return val
if typ == 'list':
return [deserialize_argument(elt) for elt in val]
if typ == 'tuple':
return tuple((deserialize_argument(elt) for elt in val))
if typ == 'dict':
return {{k: deserialize_argument(v) for k, v in val.items()}}
assert typ == 'value'
return val

Expand Down Expand Up @@ -1226,8 +1245,8 @@ def deserialize_argument(arg):

unapplied = self._batch._python_function_defs[unapplied_id]
self._user_code.append(textwrap.dedent(inspect.getsource(unapplied)))
args_str = ', '.join([f'{arg!r}' for _, arg in prepared_args])
kwargs_str = ', '.join([f'{k}={v!r}' for k, (_, v) in kwargs.items()])
args_str = ', '.join([f'{arg!r}' for _, arg in preserialized_args])
kwargs_str = ', '.join([f'{k}={v!r}' for k, (_, v) in preserialized_kwargs.items()])
separator = ', ' if args_str and kwargs_str else ''
func_call = f'{unapplied.__name__}({args_str}{separator}{kwargs_str})'
self._user_code.append(self._interpolate_command(func_call, allow_python_results=True))
Expand Down
5 changes: 4 additions & 1 deletion hail/python/hailtop/batch/resource.py
@@ -1,5 +1,5 @@
import abc
from typing import Optional, Set, cast
from typing import Optional, Set, cast, Union

from . import job # pylint: disable=cyclic-import
from .exceptions import BatchException
Expand Down Expand Up @@ -448,3 +448,6 @@ def __str__(self):

def __repr__(self):
return self._uid # pylint: disable=no-member


ResourceType = Union[PythonResult, ResourceFile, ResourceGroup]
44 changes: 44 additions & 0 deletions hail/python/test/hailtop/batch/test_batch.py
Expand Up @@ -10,7 +10,10 @@
from shlex import quote as shq
import uuid
import re
import orjson

import hailtop.fs as hfs
import hailtop.batch_client.client as bc
from hailtop import pip_version
from hailtop.batch import Batch, ServiceBackend, LocalBackend, ResourceGroup
from hailtop.batch.resource import JobResourceFile
Expand Down Expand Up @@ -1291,6 +1294,47 @@ def test_update_batch_from_batch_id(self):
res_status = res.status()
assert res_status['state'] == 'success', str((res_status, res.debug_info()))

def test_python_job_with_kwarg(self):
def foo(*, kwarg):
return kwarg

b = self.batch(default_python_image=PYTHON_DILL_IMAGE)
j = b.new_python_job()
r = j.call(foo, kwarg='hello world')

output_path = f'{self.cloud_output_dir}/test_python_job_with_kwarg'
b.write_output(r.as_json(), output_path)
res = b.run()
assert isinstance(res, bc.Batch)

assert res.status()['state'] == 'success', str((res, res.debug_info()))
with hfs.open(output_path) as f:
assert orjson.loads(f.read()) == 'hello world'

def test_tuple_recursive_resource_extraction_in_python_jobs(self):
b = self.batch(default_python_image=PYTHON_DILL_IMAGE)

def write(paths):
if not isinstance(paths, tuple):
raise ValueError('paths must be a tuple')
for i, path in enumerate(paths):
with open(path, 'w') as f:
f.write(f'{i}')

head = b.new_python_job()
head.call(write, (head.ofile1, head.ofile2))

tail = b.new_bash_job()
tail.command(f'cat {head.ofile1}')
tail.command(f'cat {head.ofile2}')

res = b.run()
assert res
assert tail._job_id
res_status = res.status()
assert res_status['state'] == 'success', str((res_status, res.debug_info()))
assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info())

def test_list_recursive_resource_extraction_in_python_jobs(self):
b = self.batch(default_python_image=PYTHON_DILL_IMAGE)

Expand Down
2 changes: 1 addition & 1 deletion hail/scripts/test_requester_pays_parsing.py
Expand Up @@ -5,7 +5,7 @@

from hailtop.aiocloud.aiogoogle import get_gcs_requester_pays_configuration
from hailtop.aiocloud.aiogoogle.user_config import get_spark_conf_gcs_requester_pays_configuration, spark_conf_path
from hailtop.config.user_config import ConfigVariable, configuration_of
from hailtop.config import ConfigVariable, configuration_of
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyright also disliked this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and it's right: ConfigVariable is not part of user_config.py

from hailtop.utils.process import check_exec_output

if 'YOU_MAY_OVERWRITE_MY_SPARK_DEFAULTS_CONF_AND_HAILCTL_SETTINGS' not in os.environ:
Expand Down