Skip to content

Commit

Permalink
persist connection data to disk as json
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Apr 8, 2011
1 parent ba75686 commit 68dde43
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 79 deletions.
83 changes: 69 additions & 14 deletions IPython/zmq/parallel/client.py
Expand Up @@ -15,22 +15,26 @@
from getpass import getpass
from pprint import pprint
from datetime import datetime
import json
pjoin = os.path.join

import zmq
from zmq.eventloop import ioloop, zmqstream

from IPython.utils.path import get_ipython_dir
from IPython.external.decorator import decorator
from IPython.zmq import tunnel

import streamsession as ss
from clusterdir import ClusterDir, ClusterDirError
# from remotenamespace import RemoteNamespace
from view import DirectView, LoadBalancedView
from dependency import Dependency, depend, require
import error
import map as Map
from asyncresult import AsyncResult, AsyncMapResult
from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
from util import ReverseDict
from util import ReverseDict, disambiguate_url, validate_url

#--------------------------------------------------------------------------
# helpers for implementing old MEC API via client.apply
Expand Down Expand Up @@ -231,7 +235,6 @@ class Client(object):
_connected=False
_ssh=False
_engines=None
_addr='tcp://127.0.0.1:10101'
_registration_socket=None
_query_socket=None
_control_socket=None
Expand All @@ -246,25 +249,59 @@ class Client(object):
debug = False
targets = None

def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
context=None, username=None, debug=False, exec_key=None,
sshserver=None, sshkey=None, password=None, paramiko=None,
exec_key=None,):
):
if context is None:
context = zmq.Context()
self.context = context
self.targets = 'all'
self._addr = addr

self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
if self._cd is not None:
if url_or_file is None:
url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
" Please specify at least one of url_or_file or profile."

try:
validate_url(url_or_file)
except AssertionError:
if not os.path.exists(url_or_file):
if self._cd:
url_or_file = os.path.join(self._cd.security_dir, url_or_file)
assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
with open(url_or_file) as f:
cfg = json.loads(f.read())
else:
cfg = {'url':url_or_file}

# sync defaults from args, json:
if sshserver:
cfg['ssh'] = sshserver
if exec_key:
cfg['exec_key'] = exec_key
exec_key = cfg['exec_key']
sshserver=cfg['ssh']
url = cfg['url']
location = cfg.setdefault('location', None)
cfg['url'] = disambiguate_url(cfg['url'], location)
url = cfg['url']

self._config = cfg


self._ssh = bool(sshserver or sshkey or password)
if self._ssh and sshserver is None:
# default to the same
sshserver = addr.split('://')[1].split(':')[0]
# default to ssh via localhost
sshserver = url.split('://')[1].split(':')[0]
if self._ssh and password is None:
if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
password=False
else:
password = getpass("SSH Password for %s: "%sshserver)
ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

if exec_key is not None and os.path.isfile(exec_key):
arg = 'keyfile'
else:
Expand All @@ -277,9 +314,9 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
self._registration_socket = self.context.socket(zmq.XREQ)
self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
if self._ssh:
tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
else:
self._registration_socket.connect(addr)
self._registration_socket.connect(url)
self._engines = ReverseDict()
self._ids = set()
self.outstanding=set()
Expand All @@ -297,6 +334,23 @@ def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, de
self._connect(sshserver, ssh_kwargs)


def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
if ipython_dir is None:
ipython_dir = get_ipython_dir()
if cluster_dir is not None:
try:
self._cd = ClusterDir.find_cluster_dir(cluster_dir)
except ClusterDirError:
pass
elif profile is not None:
try:
self._cd = ClusterDir.find_cluster_dir_by_profile(
ipython_dir, profile)
except ClusterDirError:
pass
else:
self._cd = None

@property
def ids(self):
"""Always up to date ids property."""
Expand Down Expand Up @@ -332,11 +386,12 @@ def _connect(self, sshserver, ssh_kwargs):
return
self._connected=True

def connect_socket(s, addr):
def connect_socket(s, url):
url = disambiguate_url(url, self._config['location'])
if self._ssh:
return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
else:
return s.connect(addr)
return s.connect(url)

self.session.send(self._registration_socket, 'connection_request')
idents,msg = self.session.recv(self._registration_socket,mode=0)
Expand Down Expand Up @@ -902,7 +957,7 @@ def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):

queues,targets = self._build_targets(targets)

subheader = dict(after=after, follow=follow)
subheader = {}
content = dict(bound=bound)
bufs = ss.pack_apply_message(f,args,kwargs)

Expand Down
23 changes: 15 additions & 8 deletions IPython/zmq/parallel/engine.py
Expand Up @@ -15,10 +15,11 @@

# internal
from IPython.config.configurable import Configurable
from IPython.utils.traitlets import Instance, Str, Dict, Int, Type
from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
# from IPython.utils.localinterfaces import LOCALHOST

from factory import RegistrationFactory
from util import disambiguate_url

from streamsession import Message
from streamkernel import Kernel
Expand All @@ -35,6 +36,8 @@ class EngineFactory(RegistrationFactory):
user_ns=Dict(config=True)
out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
location=Str(config=True)
timeout=CFloat(2,config=True)

# not configurable:
id=Int(allow_none=True)
Expand Down Expand Up @@ -62,6 +65,7 @@ def register(self):

def complete_registration(self, msg):
# print msg
self._abort_dc.stop()
ctx = self.context
loop = self.loop
identity = self.ident
Expand All @@ -83,20 +87,20 @@ def complete_registration(self, msg):
for addr in shell_addrs:
stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
stream.setsockopt(zmq.IDENTITY, identity)
stream.connect(addr)
stream.connect(disambiguate_url(addr, self.location))
shell_streams.append(stream)

# control stream:
control_addr = str(msg.content.control)
control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
control_stream.setsockopt(zmq.IDENTITY, identity)
control_stream.connect(control_addr)
control_stream.connect(disambiguate_url(control_addr, self.location))

# create iopub stream:
iopub_addr = msg.content.iopub
iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
iopub_stream.setsockopt(zmq.IDENTITY, identity)
iopub_stream.connect(iopub_addr)
iopub_stream.connect(disambiguate_url(iopub_addr, self.location))

# launch heartbeat
hb_addrs = msg.content.heartbeat
Expand All @@ -116,25 +120,28 @@ def complete_registration(self, msg):
control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
loop=loop, user_ns = self.user_ns, logname=self.log.name)
self.kernel.start()

hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
# ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
heart.start()


else:
self.log.error("Registration Failed: %s"%msg)
self.log.fatal("Registration Failed: %s"%msg)
raise Exception("Registration Failed: %s"%msg)

self.log.info("Completed registration with id %i"%self.id)


def unregister(self):
def abort(self):
self.log.fatal("Registration timed out")
self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
time.sleep(1)
sys.exit(0)
sys.exit(255)

def start(self):
dc = ioloop.DelayedCallback(self.register, 0, self.loop)
dc.start()
self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
self._abort_dc.start()

12 changes: 5 additions & 7 deletions IPython/zmq/parallel/factory.py
Expand Up @@ -46,7 +46,7 @@ class SessionFactory(LoggingFactory):
def _ident_default(self):
return str(uuid.uuid4())
username = Str(os.environ.get('USER','username'),config=True)
exec_key = CUnicode('',config=True)
exec_key = CStr('',config=True)
# not configurable:
context = Instance('zmq.Context', (), {})
session = Instance('IPython.zmq.parallel.streamsession.StreamSession')
Expand All @@ -57,9 +57,7 @@ def _loop_default(self):

def __init__(self, **kwargs):
super(SessionFactory, self).__init__(**kwargs)

keyfile = self.exec_key or None

exec_key = self.exec_key or None
# set the packers:
if not self.packer:
packer_f = unpacker_f = None
Expand All @@ -74,7 +72,7 @@ def __init__(self, **kwargs):
unpacker_f = import_item(self.unpacker)

# construct the session
self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, keyfile=keyfile)
self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, key=exec_key)


class RegistrationFactory(SessionFactory):
Expand All @@ -85,8 +83,8 @@ class RegistrationFactory(SessionFactory):
ip = Str('127.0.0.1', config=True)
regport = Instance(int, config=True)
def _regport_default(self):
return 10101
# return select_random_ports(1)[0]
# return 10101
return select_random_ports(1)[0]

def __init__(self, **kwargs):
super(RegistrationFactory, self).__init__(**kwargs)
Expand Down

0 comments on commit 68dde43

Please sign in to comment.