Skip to content

Commit

Permalink
Support load= keyword for Client.upload_file (#7873)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed May 31, 2023
1 parent d47f11e commit 979adb2
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 15 deletions.
23 changes: 12 additions & 11 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3730,22 +3730,23 @@ def dump_to_file(dask_worker=None):

assert all(len(data) == v for v in response.values())

def upload_file(self, filename, **kwargs):
"""Upload local package to workers
def upload_file(self, filename, load: bool = True):
"""Upload local package to scheduler and workers
This sends a local file up to all worker nodes. This file is placed
into the working directory of the running worker, see config option
This sends a local file up to the scheduler and all worker nodes.
This file is placed into the working directory of each node, see config option
``temporary-directory`` (defaults to :py:func:`tempfile.gettempdir`).
This directory will be added to the Python's system path so any .py,
.egg or .zip files will be importable.
This directory will be added to the Python's system path so any ``.py``,
``.egg`` or ``.zip`` files will be importable.
Parameters
----------
filename : string
Filename of .py, .egg or .zip file to send to workers
**kwargs : dict
Optional keyword arguments for the function
Filename of ``.py``, ``.egg``, or ``.zip`` file to send to workers
load : bool, optional
Whether or not to import the module as part of the upload process.
Defaults to ``True``.
Examples
--------
Expand All @@ -3758,9 +3759,9 @@ def upload_file(self, filename, **kwargs):
async def _():
results = await asyncio.gather(
self.register_scheduler_plugin(
SchedulerUploadFile(filename), name=name
SchedulerUploadFile(filename, load=load), name=name
),
self.register_worker_plugin(UploadFile(filename), name=name),
self.register_worker_plugin(UploadFile(filename, load=load), name=name),
)
return results[1] # Results from workers upload

Expand Down
10 changes: 6 additions & 4 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,17 @@ def _get_plugin_name(plugin: SchedulerPlugin | WorkerPlugin | NannyPlugin) -> st
class SchedulerUploadFile(SchedulerPlugin):
name = "upload_file"

def __init__(self, filepath):
def __init__(self, filepath: str, load: bool = True):
"""
Initialize the plugin by reading in the data from the given file.
"""
self.filename = os.path.basename(filepath)
self.load = load
with open(filepath, "rb") as f:
self.data = f.read()

async def start(self, scheduler: Scheduler) -> None:
await scheduler.upload_file(self.filename, self.data)
await scheduler.upload_file(self.filename, self.data, load=self.load)


class PackageInstall(WorkerPlugin, abc.ABC):
Expand Down Expand Up @@ -599,17 +600,18 @@ class UploadFile(WorkerPlugin):

name = "upload_file"

def __init__(self, filepath):
def __init__(self, filepath: str, load: bool = True):
"""
Initialize the plugin by reading in the data from the given file.
"""
self.filename = os.path.basename(filepath)
self.load = load
with open(filepath, "rb") as f:
self.data = f.read()

async def setup(self, worker):
response = await worker.upload_file(
filename=self.filename, data=self.data, load=True
filename=self.filename, data=self.data, load=self.load
)
assert len(self.data) == response["nbytes"]

Expand Down
14 changes: 14 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,20 @@ def test_upload_file_exception_sync(c):
c.upload_file(fn)


@gen_cluster(client=True)
async def test_upload_file_load(c, s, a, b):
code = "syntax-error!"
with tmp_text("myfile.py", code) as fn:
# Without `load=False` this file would be imported and raise a `SyntaxError`
await c.upload_file(fn, load=False)

# Confirm workers and scheduler got the file
for server in [s, a, b]:
file = pathlib.Path(server.local_directory).joinpath("myfile.py")
assert file.is_file()
assert file.read_text() == code


@gen_cluster(client=True, nthreads=[])
async def test_upload_file_new_worker(c, s):
def g():
Expand Down

0 comments on commit 979adb2

Please sign in to comment.