Permalink
Browse files

persist connection data to disk as json

  • Loading branch information...
1 parent ba75686 commit 68dde4349effa6378c0a8b94a0313f0857beabf5 @minrk minrk committed Feb 11, 2011
@@ -15,22 +15,26 @@
from getpass import getpass
from pprint import pprint
from datetime import datetime
+import json
+pjoin = os.path.join
import zmq
from zmq.eventloop import ioloop, zmqstream
+from IPython.utils.path import get_ipython_dir
from IPython.external.decorator import decorator
from IPython.zmq import tunnel
import streamsession as ss
+from clusterdir import ClusterDir, ClusterDirError
# from remotenamespace import RemoteNamespace
from view import DirectView, LoadBalancedView
from dependency import Dependency, depend, require
import error
import map as Map
from asyncresult import AsyncResult, AsyncMapResult
from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
-from util import ReverseDict
+from util import ReverseDict, disambiguate_url, validate_url
#--------------------------------------------------------------------------
# helpers for implementing old MEC API via client.apply
@@ -231,7 +235,6 @@ class Client(object):
_connected=False
_ssh=False
_engines=None
- _addr='tcp://127.0.0.1:10101'
_registration_socket=None
_query_socket=None
_control_socket=None
@@ -246,25 +249,59 @@ class Client(object):
debug = False
targets = None
- def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
+ def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
+ context=None, username=None, debug=False, exec_key=None,
sshserver=None, sshkey=None, password=None, paramiko=None,
- exec_key=None,):
+ ):
if context is None:
context = zmq.Context()
self.context = context
self.targets = 'all'
- self._addr = addr
+
+ self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
+ if self._cd is not None:
+ if url_or_file is None:
+ url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
+ assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
+ " Please specify at least one of url_or_file or profile."
+
+ try:
+ validate_url(url_or_file)
+ except AssertionError:
+ if not os.path.exists(url_or_file):
+ if self._cd:
+ url_or_file = os.path.join(self._cd.security_dir, url_or_file)
+ assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
+ with open(url_or_file) as f:
+ cfg = json.loads(f.read())
+ else:
+ cfg = {'url':url_or_file}
+
+ # sync defaults from args, json:
+ if sshserver:
+ cfg['ssh'] = sshserver
+ if exec_key:
+ cfg['exec_key'] = exec_key
+ exec_key = cfg['exec_key']
+ sshserver=cfg['ssh']
+ url = cfg['url']
+ location = cfg.setdefault('location', None)
+ cfg['url'] = disambiguate_url(cfg['url'], location)
+ url = cfg['url']
+
+ self._config = cfg
+
+
self._ssh = bool(sshserver or sshkey or password)
if self._ssh and sshserver is None:
- # default to the same
- sshserver = addr.split('://')[1].split(':')[0]
+ # default to ssh via localhost
+ sshserver = url.split('://')[1].split(':')[0]
if self._ssh and password is None:
if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
password=False
else:
password = getpass("SSH Password for %s: "%sshserver)
ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
-
if exec_key is not None and os.path.isfile(exec_key):
arg = 'keyfile'
else:
@@ -277,9 +314,9 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
self._registration_socket = self.context.socket(zmq.XREQ)
self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
if self._ssh:
- tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
+ tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
else:
- self._registration_socket.connect(addr)
+ self._registration_socket.connect(url)
self._engines = ReverseDict()
self._ids = set()
self.outstanding=set()
@@ -297,6 +334,23 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
self._connect(sshserver, ssh_kwargs)
+ def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
+ if ipython_dir is None:
+ ipython_dir = get_ipython_dir()
+ if cluster_dir is not None:
+ try:
+ self._cd = ClusterDir.find_cluster_dir(cluster_dir)
+ except ClusterDirError:
+ pass
+ elif profile is not None:
+ try:
+ self._cd = ClusterDir.find_cluster_dir_by_profile(
+ ipython_dir, profile)
+ except ClusterDirError:
+ pass
+ else:
+ self._cd = None
+
@property
def ids(self):
"""Always up to date ids property."""
@@ -332,11 +386,12 @@ def _connect(self, sshserver, ssh_kwargs):
return
self._connected=True
- def connect_socket(s, addr):
+ def connect_socket(s, url):
+ url = disambiguate_url(url, self._config['location'])
if self._ssh:
- return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
+ return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
else:
- return s.connect(addr)
+ return s.connect(url)
self.session.send(self._registration_socket, 'connection_request')
idents,msg = self.session.recv(self._registration_socket,mode=0)
@@ -902,7 +957,7 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
queues,targets = self._build_targets(targets)
- subheader = dict(after=after, follow=follow)
+ subheader = {}
content = dict(bound=bound)
bufs = ss.pack_apply_message(f,args,kwargs)
@@ -15,10 +15,11 @@
# internal
from IPython.config.configurable import Configurable
-from IPython.utils.traitlets import Instance, Str, Dict, Int, Type
+from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
# from IPython.utils.localinterfaces import LOCALHOST
from factory import RegistrationFactory
+from util import disambiguate_url
from streamsession import Message
from streamkernel import Kernel
@@ -35,6 +36,8 @@ class EngineFactory(RegistrationFactory):
user_ns=Dict(config=True)
out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
+ location=Str(config=True)
+ timeout=CFloat(2,config=True)
# not configurable:
id=Int(allow_none=True)
@@ -62,6 +65,7 @@ def register(self):
def complete_registration(self, msg):
# print msg
+ self._abort_dc.stop()
ctx = self.context
loop = self.loop
identity = self.ident
@@ -83,20 +87,20 @@ def complete_registration(self, msg):
for addr in shell_addrs:
stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
stream.setsockopt(zmq.IDENTITY, identity)
- stream.connect(addr)
+ stream.connect(disambiguate_url(addr, self.location))
shell_streams.append(stream)
# control stream:
control_addr = str(msg.content.control)
control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
control_stream.setsockopt(zmq.IDENTITY, identity)
- control_stream.connect(control_addr)
+ control_stream.connect(disambiguate_url(control_addr, self.location))
# create iopub stream:
iopub_addr = msg.content.iopub
iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
iopub_stream.setsockopt(zmq.IDENTITY, identity)
- iopub_stream.connect(iopub_addr)
+ iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
# launch heartbeat
hb_addrs = msg.content.heartbeat
@@ -116,25 +120,28 @@ def complete_registration(self, msg):
control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
loop=loop, user_ns = self.user_ns, logname=self.log.name)
self.kernel.start()
-
+ hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
# ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
heart.start()
else:
- self.log.error("Registration Failed: %s"%msg)
+ self.log.fatal("Registration Failed: %s"%msg)
raise Exception("Registration Failed: %s"%msg)
self.log.info("Completed registration with id %i"%self.id)
- def unregister(self):
+ def abort(self):
+ self.log.fatal("Registration timed out")
self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
time.sleep(1)
- sys.exit(0)
+ sys.exit(255)
def start(self):
dc = ioloop.DelayedCallback(self.register, 0, self.loop)
dc.start()
+ self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
+ self._abort_dc.start()
@@ -46,7 +46,7 @@ class SessionFactory(LoggingFactory):
def _ident_default(self):
return str(uuid.uuid4())
username = Str(os.environ.get('USER','username'),config=True)
- exec_key = CUnicode('',config=True)
+ exec_key = CStr('',config=True)
# not configurable:
context = Instance('zmq.Context', (), {})
session = Instance('IPython.zmq.parallel.streamsession.StreamSession')
@@ -57,9 +57,7 @@ def _loop_default(self):
def __init__(self, **kwargs):
super(SessionFactory, self).__init__(**kwargs)
-
- keyfile = self.exec_key or None
-
+ exec_key = self.exec_key or None
# set the packers:
if not self.packer:
packer_f = unpacker_f = None
@@ -74,7 +72,7 @@ def __init__(self, **kwargs):
unpacker_f = import_item(self.unpacker)
# construct the session
- self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, keyfile=keyfile)
+ self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, key=exec_key)
class RegistrationFactory(SessionFactory):
@@ -85,8 +83,8 @@ class RegistrationFactory(SessionFactory):
ip = Str('127.0.0.1', config=True)
regport = Instance(int, config=True)
def _regport_default(self):
- return 10101
- # return select_random_ports(1)[0]
+ # return 10101
+ return select_random_ports(1)[0]
def __init__(self, **kwargs):
super(RegistrationFactory, self).__init__(**kwargs)
Oops, something went wrong.

0 comments on commit 68dde43

Please sign in to comment.