Skip to content

Commit

Permalink
added exec_key and fixed client.shutdown
Browse files Browse the repository at this point in the history
  • Loading branch information
minrk committed Apr 8, 2011
1 parent 9d70cce commit 1160d9f
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 47 deletions.
43 changes: 31 additions & 12 deletions IPython/zmq/parallel/client.py
Expand Up @@ -12,6 +12,7 @@


from __future__ import print_function from __future__ import print_function


import os
import time import time
from pprint import pprint from pprint import pprint


Expand Down Expand Up @@ -139,19 +140,30 @@ class Client(object):
A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port' A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
If keyfile or password is specified, and this is not, it will default to If keyfile or password is specified, and this is not, it will default to
the ip given in addr. the ip given in addr.
keyfile : str; path to public key file sshkey : str; path to public ssh key file
This specifies a key to be used in ssh login, default None. This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument. Regular default ssh keys will be used without specifying this argument.
password : str; password : str;
Your ssh password to sshserver. Note that if this is left None, Your ssh password to sshserver. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable. you will be prompted for it if passwordless key based login is unavailable.
#------- exec authentication args -------
# If even localhost is untrusted, you can have some protection against
# unauthorized execution by using a key. Messages are still sent
# as cleartext, so if someone can snoop your loopback traffic this will
# not help anything.
exec_key : str
an authentication key or file containing a key
default: None
Attributes Attributes
---------- ----------
ids : set of int engine IDs ids : set of int engine IDs
requesting the ids attribute always synchronizes requesting the ids attribute always synchronizes
the registration state. To request ids without synchronization, the registration state. To request ids without synchronization,
use semi-private _ids. use semi-private _ids attributes.
history : list of msg_ids history : list of msg_ids
a list of msg_ids, keeping track of all the execution a list of msg_ids, keeping track of all the execution
Expand All @@ -175,7 +187,7 @@ class Client(object):
barrier : wait on one or more msg_ids barrier : wait on one or more msg_ids
execution methods: apply/apply_bound/apply_to/applu_bount execution methods: apply/apply_bound/apply_to/apply_bound
legacy: execute, run legacy: execute, run
query methods: queue_status, get_result, purge query methods: queue_status, get_result, purge
Expand All @@ -202,26 +214,32 @@ class Client(object):
debug = False debug = False


def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False, def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
sshserver=None, keyfile=None, password=None, paramiko=None): sshserver=None, sshkey=None, password=None, paramiko=None,
exec_key=None,):
if context is None: if context is None:
context = zmq.Context() context = zmq.Context()
self.context = context self.context = context
self._addr = addr self._addr = addr
self._ssh = bool(sshserver or keyfile or password) self._ssh = bool(sshserver or sshkey or password)
if self._ssh and sshserver is None: if self._ssh and sshserver is None:
# default to the same # default to the same
sshserver = addr.split('://')[1].split(':')[0] sshserver = addr.split('://')[1].split(':')[0]
if self._ssh and password is None: if self._ssh and password is None:
if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko): if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
password=False password=False
else: else:
password = getpass("SSH Password for %s: "%sshserver) password = getpass("SSH Password for %s: "%sshserver)
ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko) ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)


if os.path.isfile(exec_key):
arg = 'keyfile'
else:
arg = 'key'
key_arg = {arg:exec_key}
if username is None: if username is None:
self.session = ss.StreamSession() self.session = ss.StreamSession(**key_arg)
else: else:
self.session = ss.StreamSession(username) self.session = ss.StreamSession(username, **key_arg)
self._registration_socket = self.context.socket(zmq.XREQ) self._registration_socket = self.context.socket(zmq.XREQ)
self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
if self._ssh: if self._ssh:
Expand Down Expand Up @@ -536,11 +554,12 @@ def abort(self, msg_ids = None, targets=None, block=None):


@spinfirst @spinfirst
@defaultblock @defaultblock
def kill(self, targets=None, block=None): def shutdown(self, targets=None, restart=False, block=None):
"""Terminates one or more engine processes.""" """Terminates one or more engine processes."""
targets = self._build_targets(targets)[0] targets = self._build_targets(targets)[0]
for t in targets: for t in targets:
self.session.send(self._control_socket, 'kill_request', content={},ident=t) self.session.send(self._control_socket, 'shutdown_request',
content={'restart':restart},ident=t)
error = False error = False
if self.block: if self.block:
for i in range(len(targets)): for i in range(len(targets)):
Expand Down
26 changes: 15 additions & 11 deletions IPython/zmq/parallel/controller.py
Expand Up @@ -15,6 +15,7 @@
#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------
from __future__ import print_function from __future__ import print_function


import os
from datetime import datetime from datetime import datetime
import logging import logging


Expand All @@ -28,7 +29,7 @@


from streamsession import Message, wrap_exception from streamsession import Message, wrap_exception
from entry_point import (make_base_argument_parser, select_random_ports, split_ports, from entry_point import (make_base_argument_parser, select_random_ports, split_ports,
connect_logger, parse_url, signal_children) connect_logger, parse_url, signal_children, generate_exec_key)


#----------------------------------------------------------------------------- #-----------------------------------------------------------------------------
# Code # Code
Expand Down Expand Up @@ -283,13 +284,12 @@ def dispatch_register_request(self, msg):
logger.debug("registration::dispatch_register_request(%s)"%msg) logger.debug("registration::dispatch_register_request(%s)"%msg)
idents,msg = self.session.feed_identities(msg) idents,msg = self.session.feed_identities(msg)
if not idents: if not idents:
logger.error("Bad Queue Message: %s"%msg) logger.error("Bad Queue Message: %s"%msg, exc_info=True)
return return
try: try:
msg = self.session.unpack_message(msg,content=True) msg = self.session.unpack_message(msg,content=True)
except Exception as e: except:
logger.error("registration::got bad registration message: %s"%msg) logger.error("registration::got bad registration message: %s"%msg, exc_info=True)
raise e
return return


msg_type = msg['msg_type'] msg_type = msg['msg_type']
Expand Down Expand Up @@ -326,7 +326,7 @@ def dispatch_client_msg(self, msg):
msg = self.session.unpack_message(msg, content=True) msg = self.session.unpack_message(msg, content=True)
except: except:
content = wrap_exception() content = wrap_exception()
logger.error("Bad Client Message: %s"%msg) logger.error("Bad Client Message: %s"%msg, exc_info=True)
self.session.send(self.clientele, "controller_error", ident=client_id, self.session.send(self.clientele, "controller_error", ident=client_id,
content=content) content=content)
return return
Expand All @@ -340,7 +340,7 @@ def dispatch_client_msg(self, msg):
assert handler is not None, "Bad Message Type: %s"%msg_type assert handler is not None, "Bad Message Type: %s"%msg_type
except: except:
content = wrap_exception() content = wrap_exception()
logger.error("Bad Message Type: %s"%msg_type) logger.error("Bad Message Type: %s"%msg_type, exc_info=True)
self.session.send(self.clientele, "controller_error", ident=client_id, self.session.send(self.clientele, "controller_error", ident=client_id,
content=content) content=content)
return return
Expand Down Expand Up @@ -390,7 +390,7 @@ def save_queue_request(self, idents, msg):
try: try:
msg = self.session.unpack_message(msg, content=False) msg = self.session.unpack_message(msg, content=False)
except: except:
logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg)) logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
return return


eid = self.by_ident.get(queue_id, None) eid = self.by_ident.get(queue_id, None)
Expand All @@ -417,7 +417,7 @@ def save_queue_result(self, idents, msg):
msg = self.session.unpack_message(msg, content=False) msg = self.session.unpack_message(msg, content=False)
except: except:
logger.error("queue::engine %r sent invalid message to %r: %s"%( logger.error("queue::engine %r sent invalid message to %r: %s"%(
queue_id,client_id, msg)) queue_id,client_id, msg), exc_info=True)
return return


eid = self.by_ident.get(queue_id, None) eid = self.by_ident.get(queue_id, None)
Expand Down Expand Up @@ -448,7 +448,7 @@ def save_task_request(self, idents, msg):
msg = self.session.unpack_message(msg, content=False) msg = self.session.unpack_message(msg, content=False)
except: except:
logger.error("task::client %r sent invalid task message: %s"%( logger.error("task::client %r sent invalid task message: %s"%(
client_id, msg)) client_id, msg), exc_info=True)
return return


header = msg['header'] header = msg['header']
Expand Down Expand Up @@ -871,7 +871,11 @@ def main():
n = ZMQStream(ctx.socket(zmq.PUB), loop) n = ZMQStream(ctx.socket(zmq.PUB), loop)
nport = bind_port(n, args.ip, args.notice) nport = bind_port(n, args.ip, args.notice)


thesession = session.StreamSession(username=args.ident or "controller") ### Key File ###
if args.execkey and not os.path.isfile(args.execkey):
generate_exec_key(args.execkey)

thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey)


### build and launch the queues ### ### build and launch the queues ###


Expand Down
11 changes: 6 additions & 5 deletions IPython/zmq/parallel/engine.py
Expand Up @@ -40,7 +40,7 @@ class Engine(object):
heart=None heart=None
kernel=None kernel=None


def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None): def __init__(self, context, loop, session, registrar, client=None, ident=None):
self.context = context self.context = context
self.loop = loop self.loop = loop
self.session = session self.session = session
Expand All @@ -53,6 +53,7 @@ def register(self):


content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
self.registrar.on_recv(self.complete_registration) self.registrar.on_recv(self.complete_registration)
# print (self.session.key)
self.session.send(self.registrar, "registration_request",content=content) self.session.send(self.registrar, "registration_request",content=content)


def complete_registration(self, msg): def complete_registration(self, msg):
Expand All @@ -77,9 +78,8 @@ def complete_registration(self, msg):
sub.on_recv(lambda *a: None) sub.on_recv(lambda *a: None)
port = sub.bind_to_random_port("tcp://%s"%LOCALHOST) port = sub.bind_to_random_port("tcp://%s"%LOCALHOST)
iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345) iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345)

make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs, make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs,
client_addr=None, loop=self.loop, context=self.context) client_addr=None, loop=self.loop, context=self.context, key=self.session.key)


else: else:
# logger.error("Registration Failed: %s"%msg) # logger.error("Registration Failed: %s"%msg)
Expand Down Expand Up @@ -111,7 +111,8 @@ def main():
iface="%s://%s"%(args.transport,args.ip)+':%i' iface="%s://%s"%(args.transport,args.ip)+':%i'


loop = ioloop.IOLoop.instance() loop = ioloop.IOLoop.instance()
session = StreamSession() session = StreamSession(keyfile=args.execkey)
# print (session.key)
ctx = zmq.Context() ctx = zmq.Context()


# setup logging # setup logging
Expand All @@ -124,7 +125,7 @@ def main():
reg = ctx.socket(zmq.PAIR) reg = ctx.socket(zmq.PAIR)
reg.connect(reg_conn) reg.connect(reg_conn)
reg = zmqstream.ZMQStream(reg, loop) reg = zmqstream.ZMQStream(reg, loop)
client = Client(reg_conn) client = None


e = Engine(ctx, loop, session, reg, client, args.ident) e = Engine(ctx, loop, session, reg, client, args.ident)
dc = ioloop.DelayedCallback(e.start, 100, loop) dc = ioloop.DelayedCallback(e.start, 100, loop)
Expand Down
18 changes: 17 additions & 1 deletion IPython/zmq/parallel/entry_point.py
Expand Up @@ -7,6 +7,7 @@
import atexit import atexit
import sys import sys
import os import os
import stat
import socket import socket
from subprocess import Popen, PIPE from subprocess import Popen, PIPE
from signal import signal, SIGINT, SIGABRT, SIGTERM from signal import signal, SIGINT, SIGABRT, SIGTERM
Expand All @@ -33,7 +34,7 @@ def split_ports(s, n):
return ports return ports


def select_random_ports(n): def select_random_ports(n):
"""Selects and return n random ports that are open.""" """Selects and return n random ports that are available."""
ports = [] ports = []
for i in xrange(n): for i in xrange(n):
sock = socket.socket() sock = socket.socket()
Expand All @@ -46,6 +47,7 @@ def select_random_ports(n):
return ports return ports


def parse_url(args): def parse_url(args):
"""Ensure args.url contains full transport://interface:port"""
if args.url: if args.url:
iface = args.url.split('://',1) iface = args.url.split('://',1)
if len(args) == 2: if len(args) == 2:
Expand All @@ -57,13 +59,25 @@ def parse_url(args):
args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport) args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport)


def signal_children(children): def signal_children(children):
"""Relay interupt/term signals to children, for more solid process cleanup."""
def terminate_children(sig, frame): def terminate_children(sig, frame):
for child in children: for child in children:
child.terminate() child.terminate()
# sys.exit(sig) # sys.exit(sig)
for sig in (SIGINT, SIGABRT, SIGTERM): for sig in (SIGINT, SIGABRT, SIGTERM):
signal(sig, terminate_children) signal(sig, terminate_children)


def generate_exec_key(keyfile):
import uuid
newkey = str(uuid.uuid4())
with open(keyfile, 'w') as f:
# f.write('ipython-key ')
f.write(newkey)
# set user-only RW permissions (0600)
# this will have no effect on Windows
os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)


def make_base_argument_parser(): def make_base_argument_parser():
""" Creates an ArgumentParser for the generic arguments supported by all """ Creates an ArgumentParser for the generic arguments supported by all
ipcluster entry points. ipcluster entry points.
Expand All @@ -86,6 +100,8 @@ def make_base_argument_parser():
help='set the message format method [default: json]') help='set the message format method [default: json]')
parser.add_argument('--url', type=str, parser.add_argument('--url', type=str,
help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101') help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101')
parser.add_argument('--execkey', type=str,
help="File containing key for authenticating requests.")


return parser return parser


Expand Down
2 changes: 1 addition & 1 deletion IPython/zmq/parallel/ipcluster.py
Expand Up @@ -65,7 +65,7 @@ def main():


controller_args = strip_args([('--n','-n')]) controller_args = strip_args([('--n','-n')])
engine_args = filter_args(['--url', '--regport', '--logport', '--ip', engine_args = filter_args(['--url', '--regport', '--logport', '--ip',
'--transport','--loglevel','--packer'])+['--ident'] '--transport','--loglevel','--packer', '--execkey'])+['--ident']


controller = launch_process('controller', controller_args) controller = launch_process('controller', controller_args)
for i in range(10): for i in range(10):
Expand Down
26 changes: 18 additions & 8 deletions IPython/zmq/parallel/streamkernel.py
Expand Up @@ -127,17 +127,21 @@ def shutdown_request(self, stream, ident, parent):
"""kill ourself. This should really be handled in an external process""" """kill ourself. This should really be handled in an external process"""
self.abort_queues() self.abort_queues()
content = dict(parent['content']) content = dict(parent['content'])
msg = self.session.send(self.reply_socket, 'shutdown_reply', msg = self.session.send(stream, 'shutdown_reply',
content, parent, ident) content=content, parent=parent, ident=ident)
msg = self.session.send(self.pub_socket, 'shutdown_reply', # msg = self.session.send(self.pub_socket, 'shutdown_reply',
content, parent, ident) # content, parent, ident)
# print >> sys.__stdout__, msg # print >> sys.__stdout__, msg
time.sleep(0.1) time.sleep(0.1)
sys.exit(0) sys.exit(0)


def dispatch_control(self, msg): def dispatch_control(self, msg):
idents,msg = self.session.feed_identities(msg, copy=False) idents,msg = self.session.feed_identities(msg, copy=False)
msg = self.session.unpack_message(msg, content=True, copy=False) try:
msg = self.session.unpack_message(msg, content=True, copy=False)
except:
logger.error("Invalid Message", exc_info=True)
return


header = msg['header'] header = msg['header']
msg_id = header['msg_id'] msg_id = header['msg_id']
Expand Down Expand Up @@ -313,7 +317,12 @@ def apply_request(self, stream, ident, parent):
def dispatch_queue(self, stream, msg): def dispatch_queue(self, stream, msg):
self.control_stream.flush() self.control_stream.flush()
idents,msg = self.session.feed_identities(msg, copy=False) idents,msg = self.session.feed_identities(msg, copy=False)
msg = self.session.unpack_message(msg, content=True, copy=False) try:
msg = self.session.unpack_message(msg, content=True, copy=False)
except:
logger.error("Invalid Message", exc_info=True)
return



header = msg['header'] header = msg['header']
msg_id = header['msg_id'] msg_id = header['msg_id']
Expand Down Expand Up @@ -367,14 +376,15 @@ def start(self):
# time.sleep(1e-3) # time.sleep(1e-3)


def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs, def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
client_addr=None, loop=None, context=None): client_addr=None, loop=None, context=None, key=None):
# create loop, context, and session: # create loop, context, and session:
if loop is None: if loop is None:
loop = ioloop.IOLoop.instance() loop = ioloop.IOLoop.instance()
if context is None: if context is None:
context = zmq.Context() context = zmq.Context()
c = context c = context
session = StreamSession() session = StreamSession(key=key)
# print (session.key)
print (control_addr, shell_addrs, iopub_addr, hb_addrs) print (control_addr, shell_addrs, iopub_addr, hb_addrs)


# create Control Stream # create Control Stream
Expand Down

0 comments on commit 1160d9f

Please sign in to comment.