Permalink
Browse files

added exec_key and fixed client.shutdown

  • Loading branch information...
1 parent 9d70cce commit 1160d9f8398ae7dc37273a3f4d3a1d5a3e4c05d6 @minrk minrk committed Nov 20, 2010
@@ -12,6 +12,7 @@
from __future__ import print_function
+import os
import time
from pprint import pprint
@@ -139,19 +140,30 @@ class Client(object):
A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
If keyfile or password is specified, and this is not, it will default to
the ip given in addr.
- keyfile : str; path to public key file
+ sshkey : str; path to public ssh key file
This specifies a key to be used in ssh login, default None.
Regular default ssh keys will be used without specifying this argument.
password : str;
Your ssh password to sshserver. Note that if this is left None,
you will be prompted for it if passwordless key based login is unavailable.
+ #------- exec authentication args -------
+ # If even localhost is untrusted, you can have some protection against
+ # unauthorized execution by using a key. Messages are still sent
+ # as cleartext, so if someone can snoop your loopback traffic this will
+ # not help anything.
+
+ exec_key : str
+ an authentication key or file containing a key
+ default: None
+
+
Attributes
----------
ids : set of int engine IDs
requesting the ids attribute always synchronizes
the registration state. To request ids without synchronization,
- use semi-private _ids.
+ use semi-private _ids attributes.
history : list of msg_ids
a list of msg_ids, keeping track of all the execution
@@ -175,7 +187,7 @@ class Client(object):
barrier : wait on one or more msg_ids
- execution methods: apply/apply_bound/apply_to/applu_bount
+ execution methods: apply/apply_bound/apply_to/apply_bound
legacy: execute, run
query methods: queue_status, get_result, purge
@@ -202,26 +214,32 @@ class Client(object):
debug = False
def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
- sshserver=None, keyfile=None, password=None, paramiko=None):
+ sshserver=None, sshkey=None, password=None, paramiko=None,
+ exec_key=None,):
if context is None:
context = zmq.Context()
self.context = context
self._addr = addr
- self._ssh = bool(sshserver or keyfile or password)
+ self._ssh = bool(sshserver or sshkey or password)
if self._ssh and sshserver is None:
# default to the same
sshserver = addr.split('://')[1].split(':')[0]
if self._ssh and password is None:
- if tunnel.try_passwordless_ssh(sshserver, keyfile, paramiko):
+ if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
password=False
else:
password = getpass("SSH Password for %s: "%sshserver)
- ssh_kwargs = dict(keyfile=keyfile, password=password, paramiko=paramiko)
-
+ ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
+
+ if os.path.isfile(exec_key):
+ arg = 'keyfile'
+ else:
+ arg = 'key'
+ key_arg = {arg:exec_key}
if username is None:
- self.session = ss.StreamSession()
+ self.session = ss.StreamSession(**key_arg)
else:
- self.session = ss.StreamSession(username)
+ self.session = ss.StreamSession(username, **key_arg)
self._registration_socket = self.context.socket(zmq.XREQ)
self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
if self._ssh:
@@ -536,11 +554,12 @@ def abort(self, msg_ids = None, targets=None, block=None):
@spinfirst
@defaultblock
- def kill(self, targets=None, block=None):
+ def shutdown(self, targets=None, restart=False, block=None):
"""Terminates one or more engine processes."""
targets = self._build_targets(targets)[0]
for t in targets:
- self.session.send(self._control_socket, 'kill_request', content={},ident=t)
+ self.session.send(self._control_socket, 'shutdown_request',
+ content={'restart':restart},ident=t)
error = False
if self.block:
for i in range(len(targets)):
@@ -15,6 +15,7 @@
#-----------------------------------------------------------------------------
from __future__ import print_function
+import os
from datetime import datetime
import logging
@@ -28,7 +29,7 @@
from streamsession import Message, wrap_exception
from entry_point import (make_base_argument_parser, select_random_ports, split_ports,
- connect_logger, parse_url, signal_children)
+ connect_logger, parse_url, signal_children, generate_exec_key)
#-----------------------------------------------------------------------------
# Code
@@ -283,13 +284,12 @@ def dispatch_register_request(self, msg):
logger.debug("registration::dispatch_register_request(%s)"%msg)
idents,msg = self.session.feed_identities(msg)
if not idents:
- logger.error("Bad Queue Message: %s"%msg)
+ logger.error("Bad Queue Message: %s"%msg, exc_info=True)
return
try:
msg = self.session.unpack_message(msg,content=True)
- except Exception as e:
- logger.error("registration::got bad registration message: %s"%msg)
- raise e
+ except:
+ logger.error("registration::got bad registration message: %s"%msg, exc_info=True)
return
msg_type = msg['msg_type']
@@ -326,7 +326,7 @@ def dispatch_client_msg(self, msg):
msg = self.session.unpack_message(msg, content=True)
except:
content = wrap_exception()
- logger.error("Bad Client Message: %s"%msg)
+ logger.error("Bad Client Message: %s"%msg, exc_info=True)
self.session.send(self.clientele, "controller_error", ident=client_id,
content=content)
return
@@ -340,7 +340,7 @@ def dispatch_client_msg(self, msg):
assert handler is not None, "Bad Message Type: %s"%msg_type
except:
content = wrap_exception()
- logger.error("Bad Message Type: %s"%msg_type)
+ logger.error("Bad Message Type: %s"%msg_type, exc_info=True)
self.session.send(self.clientele, "controller_error", ident=client_id,
content=content)
return
@@ -390,7 +390,7 @@ def save_queue_request(self, idents, msg):
try:
msg = self.session.unpack_message(msg, content=False)
except:
- logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg))
+ logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
return
eid = self.by_ident.get(queue_id, None)
@@ -417,7 +417,7 @@ def save_queue_result(self, idents, msg):
msg = self.session.unpack_message(msg, content=False)
except:
logger.error("queue::engine %r sent invalid message to %r: %s"%(
- queue_id,client_id, msg))
+ queue_id,client_id, msg), exc_info=True)
return
eid = self.by_ident.get(queue_id, None)
@@ -448,7 +448,7 @@ def save_task_request(self, idents, msg):
msg = self.session.unpack_message(msg, content=False)
except:
logger.error("task::client %r sent invalid task message: %s"%(
- client_id, msg))
+ client_id, msg), exc_info=True)
return
header = msg['header']
@@ -871,7 +871,11 @@ def main():
n = ZMQStream(ctx.socket(zmq.PUB), loop)
nport = bind_port(n, args.ip, args.notice)
- thesession = session.StreamSession(username=args.ident or "controller")
+ ### Key File ###
+ if args.execkey and not os.path.isfile(args.execkey):
+ generate_exec_key(args.execkey)
+
+ thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey)
### build and launch the queues ###
@@ -40,7 +40,7 @@ class Engine(object):
heart=None
kernel=None
- def __init__(self, context, loop, session, registrar, client, ident=None, heart_id=None):
+ def __init__(self, context, loop, session, registrar, client=None, ident=None):
self.context = context
self.loop = loop
self.session = session
@@ -53,6 +53,7 @@ def register(self):
content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
self.registrar.on_recv(self.complete_registration)
+ # print (self.session.key)
self.session.send(self.registrar, "registration_request",content=content)
def complete_registration(self, msg):
@@ -77,9 +78,8 @@ def complete_registration(self, msg):
sub.on_recv(lambda *a: None)
port = sub.bind_to_random_port("tcp://%s"%LOCALHOST)
iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345)
-
make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs,
- client_addr=None, loop=self.loop, context=self.context)
+ client_addr=None, loop=self.loop, context=self.context, key=self.session.key)
else:
# logger.error("Registration Failed: %s"%msg)
@@ -111,7 +111,8 @@ def main():
iface="%s://%s"%(args.transport,args.ip)+':%i'
loop = ioloop.IOLoop.instance()
- session = StreamSession()
+ session = StreamSession(keyfile=args.execkey)
+ # print (session.key)
ctx = zmq.Context()
# setup logging
@@ -124,7 +125,7 @@ def main():
reg = ctx.socket(zmq.PAIR)
reg.connect(reg_conn)
reg = zmqstream.ZMQStream(reg, loop)
- client = Client(reg_conn)
+ client = None
e = Engine(ctx, loop, session, reg, client, args.ident)
dc = ioloop.DelayedCallback(e.start, 100, loop)
@@ -7,6 +7,7 @@
import atexit
import sys
import os
+import stat
import socket
from subprocess import Popen, PIPE
from signal import signal, SIGINT, SIGABRT, SIGTERM
@@ -33,7 +34,7 @@ def split_ports(s, n):
return ports
def select_random_ports(n):
- """Selects and return n random ports that are open."""
+ """Selects and return n random ports that are available."""
ports = []
for i in xrange(n):
sock = socket.socket()
@@ -46,6 +47,7 @@ def select_random_ports(n):
return ports
def parse_url(args):
+ """Ensure args.url contains full transport://interface:port"""
if args.url:
iface = args.url.split('://',1)
if len(args) == 2:
@@ -57,13 +59,25 @@ def parse_url(args):
args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport)
def signal_children(children):
+ """Relay interupt/term signals to children, for more solid process cleanup."""
def terminate_children(sig, frame):
for child in children:
child.terminate()
# sys.exit(sig)
for sig in (SIGINT, SIGABRT, SIGTERM):
signal(sig, terminate_children)
+def generate_exec_key(keyfile):
+ import uuid
+ newkey = str(uuid.uuid4())
+ with open(keyfile, 'w') as f:
+ # f.write('ipython-key ')
+ f.write(newkey)
+ # set user-only RW permissions (0600)
+ # this will have no effect on Windows
+ os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
+
+
def make_base_argument_parser():
""" Creates an ArgumentParser for the generic arguments supported by all
ipcluster entry points.
@@ -86,6 +100,8 @@ def make_base_argument_parser():
help='set the message format method [default: json]')
parser.add_argument('--url', type=str,
help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101')
+ parser.add_argument('--execkey', type=str,
+ help="File containing key for authenticating requests.")
return parser
@@ -65,7 +65,7 @@ def main():
controller_args = strip_args([('--n','-n')])
engine_args = filter_args(['--url', '--regport', '--logport', '--ip',
- '--transport','--loglevel','--packer'])+['--ident']
+ '--transport','--loglevel','--packer', '--execkey'])+['--ident']
controller = launch_process('controller', controller_args)
for i in range(10):
@@ -127,17 +127,21 @@ def shutdown_request(self, stream, ident, parent):
"""kill ourself. This should really be handled in an external process"""
self.abort_queues()
content = dict(parent['content'])
- msg = self.session.send(self.reply_socket, 'shutdown_reply',
- content, parent, ident)
- msg = self.session.send(self.pub_socket, 'shutdown_reply',
- content, parent, ident)
+ msg = self.session.send(stream, 'shutdown_reply',
+ content=content, parent=parent, ident=ident)
+ # msg = self.session.send(self.pub_socket, 'shutdown_reply',
+ # content, parent, ident)
# print >> sys.__stdout__, msg
time.sleep(0.1)
sys.exit(0)
def dispatch_control(self, msg):
idents,msg = self.session.feed_identities(msg, copy=False)
- msg = self.session.unpack_message(msg, content=True, copy=False)
+ try:
+ msg = self.session.unpack_message(msg, content=True, copy=False)
+ except:
+ logger.error("Invalid Message", exc_info=True)
+ return
header = msg['header']
msg_id = header['msg_id']
@@ -313,7 +317,12 @@ def apply_request(self, stream, ident, parent):
def dispatch_queue(self, stream, msg):
self.control_stream.flush()
idents,msg = self.session.feed_identities(msg, copy=False)
- msg = self.session.unpack_message(msg, content=True, copy=False)
+ try:
+ msg = self.session.unpack_message(msg, content=True, copy=False)
+ except:
+ logger.error("Invalid Message", exc_info=True)
+ return
+
header = msg['header']
msg_id = header['msg_id']
@@ -367,14 +376,15 @@ def start(self):
# time.sleep(1e-3)
def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
- client_addr=None, loop=None, context=None):
+ client_addr=None, loop=None, context=None, key=None):
# create loop, context, and session:
if loop is None:
loop = ioloop.IOLoop.instance()
if context is None:
context = zmq.Context()
c = context
- session = StreamSession()
+ session = StreamSession(key=key)
+ # print (session.key)
print (control_addr, shell_addrs, iopub_addr, hb_addrs)
# create Control Stream
Oops, something went wrong.

0 comments on commit 1160d9f

Please sign in to comment.