Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use kwargs in CLI #2899

Merged
merged 5 commits on Jul 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -24,7 +24,7 @@ install:
script:
- if [[ $TESTS == true ]]; then source continuous_integration/travis/run_tests.sh ; fi
- if [[ $LINT == true ]]; then pip install flake8 ; flake8 distributed ; fi
- if [[ $LINT == true ]]; then pip install black; black distributed --check; fi
- if [[ $LINT == true ]]; then pip install git+https://github.com/psf/black@cad4138050b86d1c8570b926883e32f7465c2880; black distributed --check; fi

after_success:
- if [[ $COVERAGE == true ]]; then coverage report; pip install -q coveralls ; coveralls ; fi
Expand Down
14 changes: 2 additions & 12 deletions distributed/cli/dask_scheduler.py
Expand Up @@ -136,17 +136,12 @@ def main(
dashboard_prefix,
use_xheaders,
pid_file,
scheduler_file,
interface,
protocol,
local_directory,
preload,
preload_argv,
tls_ca_file,
tls_cert,
tls_key,
dashboard_address,
idle_timeout,
**kwargs
):
g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653
gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)
Expand Down Expand Up @@ -217,17 +212,12 @@ def del_pid_file():

scheduler = Scheduler(
loop=loop,
scheduler_file=scheduler_file,
security=sec,
host=host,
port=port,
interface=interface,
protocol=protocol,
dashboard_address=dashboard_address if dashboard else None,
service_kwargs={"dashboard": {"prefix": dashboard_prefix}},
idle_timeout=idle_timeout,
preload=preload,
preload_argv=preload_argv,
**kwargs,
)
logger.info("Local Directory: %26s", local_directory)
logger.info("-" * 47)
Expand Down
29 changes: 3 additions & 26 deletions distributed/cli/dask_worker.py
Expand Up @@ -11,7 +11,6 @@
import click
import dask
from distributed import Nanny, Worker
from distributed.utils import parse_timedelta
from distributed.security import Security
from distributed.cli.utils import check_python_3, install_signal_handlers
from distributed.comm import get_address_host_port
Expand Down Expand Up @@ -199,25 +198,18 @@ def main(
nprocs,
nanny,
name,
memory_limit,
pid_file,
reconnect,
resources,
dashboard,
bokeh,
bokeh_port,
local_directory,
scheduler_file,
interface,
protocol,
death_timeout,
preload,
preload_argv,
dashboard_prefix,
tls_ca_file,
tls_cert,
tls_key,
dashboard_address,
**kwargs
):
g0, g1, g2 = gc.get_threshold() # https://github.com/dask/distributed/issues/1653
gc.set_threshold(g0 * 3, g1 * 3, g2 * 3)
Expand Down Expand Up @@ -314,8 +306,6 @@ def del_pid_file():

atexit.register(del_pid_file)

services = {}

if resources:
resources = resources.replace(",", " ").split()
resources = dict(pair.split("=") for pair in resources)
Expand All @@ -326,10 +316,9 @@ def del_pid_file():
loop = IOLoop.current()

if nanny:
kwargs = {"worker_port": worker_port, "listen_address": listen_address}
kwargs.update({"worker_port": worker_port, "listen_address": listen_address})
t = Nanny
else:
kwargs = {}
if nanny_port:
kwargs["service_ports"] = {"nanny": nanny_port}
t = Worker
Expand All @@ -344,33 +333,21 @@ def del_pid_file():
"dask-worker SCHEDULER_ADDRESS:8786"
)

if death_timeout is not None:
death_timeout = parse_timedelta(death_timeout, "s")

nannies = [
t(
scheduler,
scheduler_file=scheduler_file,
nthreads=nthreads,
services=services,
loop=loop,
resources=resources,
memory_limit=memory_limit,
reconnect=reconnect,
local_directory=local_directory,
death_timeout=death_timeout,
preload=preload,
preload_argv=preload_argv,
security=sec,
contact_address=contact_address,
interface=interface,
protocol=protocol,
host=host,
port=port,
dashboard_address=dashboard_address if dashboard else None,
service_kwargs={"dashboard": {"prefix": dashboard_prefix}},
name=name if nprocs == 1 or not name else name + "-" + str(i),
**kwargs
**kwargs,
)
for i in range(nprocs)
]
Expand Down
4 changes: 3 additions & 1 deletion distributed/comm/ucx.py
Expand Up @@ -121,7 +121,9 @@ async def read(self, deserializers=("cuda", "dask", "pickle", "error")):
deserializers = ("cuda", "dask", "pickle", "error")
resp = await self.ep.recv_future()
obj = ucp.get_obj_from_msg(resp)
nframes, = struct.unpack("Q", obj[:8]) # first eight bytes for number of frames
(nframes,) = struct.unpack(
"Q", obj[:8]
) # first eight bytes for number of frames

gpu_frame_msg = obj[
8 : 8 + nframes
Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/cupy.py
Expand Up @@ -31,7 +31,7 @@ def serialize_cupy_ndarray(x):

@cuda_deserialize.register(cupy.ndarray)
def deserialize_cupy_array(header, frames):
frame, = frames
(frame,) = frames
# TODO: put this in ucx... as a kind of "fixup"
try:
frame.typestr = header["typestr"]
Expand Down
2 changes: 1 addition & 1 deletion distributed/protocol/numba.py
Expand Up @@ -36,7 +36,7 @@ def serialize_numba_ndarray(x):

@cuda_deserialize.register(numba.cuda.devicearray.DeviceNDArray)
def deserialize_numba_ndarray(header, frames):
frame, = frames
(frame,) = frames
# TODO: put this in ucx... as a kind of "fixup"
if isinstance(frame, bytes):
import numpy as np
Expand Down