@@ -27,8 +27,9 @@ class AfarException(Exception):
2727
2828
2929class Where :
30- def __init__ (self , where , submit_kwargs = None ):
30+ def __init__ (self , where , client = None , submit_kwargs = None ):
3131 self .where = where
32+ self .client = client
3233 self .submit_kwargs = submit_kwargs
3334
3435 def __enter__ (self ):
@@ -37,8 +38,8 @@ def __enter__(self):
3738 def __exit__ (self , exc_type , exc_value , exc_traceback ): # pragma: no cover
3839 return False
3940
40- def __call__ (self , ** submit_kwargs ):
41- return Where (self .where , submit_kwargs )
41+ def __call__ (self , * , client = None , * *submit_kwargs ):
42+ return Where (self .where , client , submit_kwargs )
4243
4344
4445remotely = Where ("remotely" )
@@ -166,9 +167,11 @@ def _exit(self, exc_type, exc_value, exc_traceback):
166167 where = exc_value .args [0 ]
167168 self ._where = where .where
168169 submit_kwargs = where .submit_kwargs or {}
170+ client = where .client
169171 elif issubclass (exc_type , NameError ) and exc_value .args [0 ] in _errors_to_locations :
170172 self ._where = _errors_to_locations [exc_value .args [0 ]]
171173 submit_kwargs = {}
174+ client = None
172175 else :
173176 # The exception is valid
174177 return False
@@ -189,7 +192,8 @@ def _exit(self, exc_type, exc_value, exc_traceback):
189192 display_expr = self ._magic_func ._display_expr
190193
191194 if self ._where == "remotely" :
192- client = distributed .client ._get_global_client ()
195+ if client is None :
196+ client = distributed .client ._get_global_client ()
193197 remote_dict = client .submit (run_afar , self ._magic_func , names , futures , ** submit_kwargs )
194198 if display_expr :
195199 repr_val = client .submit (
0 commit comments