Skip to content

Commit

Permalink
Add support for kwargs to elfi.set_client (#271)
Browse files Browse the repository at this point in the history
  • Loading branch information
vuolleko committed Jun 8, 2018
1 parent 538a1d9 commit 043c87e
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

dev
---
- Add support for kwargs to elfi.set_client

0.7.1 (2018-04-11)
------------------
- Implemented model selection (elfi.compare_models). See API documentation.
Expand Down
15 changes: 12 additions & 3 deletions elfi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,22 @@ def get_client():
return _client


def set_client(client=None):
"""Set the current ELFI client instance."""
def set_client(client=None, **kwargs):
"""Set the current ELFI client instance.
Parameters
----------
client : ClientBase or str
Instance of a client from ClientBase,
or a string from ['native', 'multiprocessing', 'ipyparallel'].
If string, the respective constructor is called with `kwargs`.
"""
global _client

if isinstance(client, str):
m = importlib.import_module('elfi.clients.{}'.format(client))
client = m.Client()
client = m.Client(**kwargs)

_client = client

Expand Down
4 changes: 2 additions & 2 deletions elfi/clients/ipyparallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Client(elfi.client.ClientBase):
http://ipyparallel.readthedocs.io
"""

def __init__(self, ipp_client=None):
def __init__(self, ipp_client=None, **kwargs):
"""Create an ipyparallel client for ELFI.
Parameters
Expand All @@ -34,7 +34,7 @@ def __init__(self, ipp_client=None):
Use this ipyparallel client with ELFI.
"""
self.ipp_client = ipp_client or ipp.Client()
self.ipp_client = ipp_client or ipp.Client(**kwargs)
self.view = self.ipp_client.load_balanced_view()

self.tasks = {}
Expand Down
5 changes: 3 additions & 2 deletions elfi/clients/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def set_as_default():
class Client(elfi.client.ClientBase):
"""Client based on Python's built-in multiprocessing module."""

def __init__(self, num_processes=None):
def __init__(self, num_processes=None, **kwargs):
"""Create a multiprocessing client.
Parameters
Expand All @@ -27,7 +27,8 @@ def __init__(self, num_processes=None):
Number of worker processes to use. Defaults to os.cpu_count().
"""
self.pool = multiprocessing.Pool(processes=num_processes)
num_processes = num_processes or kwargs.pop('processes', None)
self.pool = multiprocessing.Pool(processes=num_processes, **kwargs)

self.tasks = {}
self._id_counter = itertools.count()
Expand Down
2 changes: 1 addition & 1 deletion elfi/clients/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Client(elfi.client.ClientBase):
Responsible for sending computational graphs to be executed in an Executor
"""

def __init__(self):
def __init__(self, **kwargs):
"""Create a native client."""
self.tasks = {}
self._ids = itertools.count()
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,15 @@ def test_batch_handler(simple_model):
assert not np.array_equal(out0['k2'], out1['k2'])


def test_multiprocessing_kwargs(simple_model):
m = simple_model
num_proc = 2
elfi.set_client('multiprocessing', num_processes=num_proc)
rej = elfi.Rejection(m['k1'])
assert rej.client.num_cores == num_proc

elfi.set_client('multiprocessing', processes=num_proc)
rej = elfi.Rejection(m['k1'])
assert rej.client.num_cores == num_proc

# TODO: add testing that client is cleared from tasks after they are retrieved

0 comments on commit 043c87e

Please sign in to comment.