Skip to content

Commit

Permalink
Merge pull request ipython#1868 from minrk/ipc
Browse files Browse the repository at this point in the history
enable IPC transport for kernels

works with the qtconsole

Config is a bit clumsy, because the interpretation of 'ip' is actually a path when transport is IPC.

Notebook does not yet expose the option, because it's still not well integrated into the rest of the config universe.
  • Loading branch information
Carreau committed Sep 29, 2012
2 parents 268ef2a + f67b88e commit 25951fa
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 53 deletions.
34 changes: 28 additions & 6 deletions IPython/frontend/consoleapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import atexit
import json
import os
import shutil
import signal
import sys
import uuid
Expand All @@ -38,7 +39,7 @@
from IPython.utils.path import filefind
from IPython.utils.py3compat import str_to_bytes
from IPython.utils.traitlets import (
Dict, List, Unicode, CUnicode, Int, CBool, Any
Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum
)
from IPython.zmq.ipkernel import (
flags as ipkernel_flags,
Expand Down Expand Up @@ -151,12 +152,27 @@ class IPythonConsoleApp(Configurable):
# create requested profiles by default, if they don't exist:
auto_create = CBool(True)
# connection info:
ip = Unicode(LOCALHOST, config=True,

transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)

ip = Unicode(config=True,
help="""Set the kernel\'s IP address [default localhost].
If the IP address is something other than localhost, then
Consoles on other machines will be able to connect
to the Kernel, so be careful!"""
)
def _ip_default(self):
if self.transport == 'tcp':
return LOCALHOST
else:
# this can fire early if ip is given,
# in which case our return value is meaningless
if not hasattr(self, 'profile_dir'):
return ''
ipcdir = os.path.join(self.profile_dir.security_dir, 'kernel-%s' % os.getpid())
os.makedirs(ipcdir)
atexit.register(lambda : shutil.rmtree(ipcdir))
return os.path.join(ipcdir, 'ipc')

sshserver = Unicode('', config=True,
help="""The SSH server to use to connect to the kernel.""")
Expand Down Expand Up @@ -256,10 +272,10 @@ def load_connection_file(self):
return
self.log.debug(u"Loading connection file %s", fname)
with open(fname) as f:
s = f.read()
cfg = json.loads(s)
if self.ip == LOCALHOST and 'ip' in cfg:
# not overridden by config or cl_args
cfg = json.load(f)

self.transport = cfg.get('transport', 'tcp')
if 'ip' in cfg:
self.ip = cfg['ip']
for channel in ('hb', 'shell', 'iopub', 'stdin'):
name = channel + '_port'
Expand All @@ -268,12 +284,17 @@ def load_connection_file(self):
setattr(self, name, cfg[name])
if 'key' in cfg:
self.config.Session.key = str_to_bytes(cfg['key'])


def init_ssh(self):
"""set up ssh tunnels, if needed."""
if not self.sshserver and not self.sshkey:
return

if self.transport != 'tcp':
self.log.error("Can only use ssh tunnels with TCP sockets, not %s", self.transport)
return

if self.sshkey and not self.sshserver:
# specifying just the key implies that we are connecting directly
self.sshserver = self.ip
Expand Down Expand Up @@ -326,6 +347,7 @@ def init_kernel_manager(self):

# Create a KernelManager and start a kernel.
self.kernel_manager = self.kernel_manager_class(
transport=self.transport,
ip=self.ip,
shell_port=self.shell_port,
iopub_port=self.iopub_port,
Expand Down
28 changes: 19 additions & 9 deletions IPython/zmq/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from parentpoller import ParentPollerWindows

def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
ip=LOCALHOST, key=b''):
ip=LOCALHOST, key=b'', transport='tcp'):
"""Generates a JSON config file, including the selection of random ports.
Parameters
Expand Down Expand Up @@ -54,17 +54,26 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0,
fname = tempfile.mktemp('.json')

# Find open ports as necessary.

ports = []
ports_needed = int(shell_port <= 0) + int(iopub_port <= 0) + \
int(stdin_port <= 0) + int(hb_port <= 0)
for i in xrange(ports_needed):
sock = socket.socket()
sock.bind(('', 0))
ports.append(sock)
for i, sock in enumerate(ports):
port = sock.getsockname()[1]
sock.close()
ports[i] = port
if transport == 'tcp':
for i in range(ports_needed):
sock = socket.socket()
sock.bind(('', 0))
ports.append(sock)
for i, sock in enumerate(ports):
port = sock.getsockname()[1]
sock.close()
ports[i] = port
else:
N = 1
for i in range(ports_needed):
while os.path.exists("%s-%s" % (ip, str(N))):
N += 1
ports.append(N)
N += 1
if shell_port <= 0:
shell_port = ports.pop(0)
if iopub_port <= 0:
Expand All @@ -81,6 +90,7 @@ def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0,
)
cfg['ip'] = ip
cfg['key'] = bytes_to_str(key)
cfg['transport'] = transport

with open(fname, 'w') as f:
f.write(json.dumps(cfg, indent=2))
Expand Down
24 changes: 16 additions & 8 deletions IPython/zmq/heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# Imports
#-----------------------------------------------------------------------------

import os
import socket
import sys
from threading import Thread
Expand All @@ -28,21 +29,28 @@
class Heartbeat(Thread):
"A simple ping-pong style heartbeat that runs in a thread."

def __init__(self, context, addr=(LOCALHOST, 0)):
def __init__(self, context, addr=('tcp', LOCALHOST, 0)):
Thread.__init__(self)
self.context = context
self.ip, self.port = addr
self.transport, self.ip, self.port = addr
if self.port == 0:
s = socket.socket()
# '*' means all interfaces to 0MQ, which is '' to socket.socket
s.bind(('' if self.ip == '*' else self.ip, 0))
self.port = s.getsockname()[1]
s.close()
if addr[0] == 'tcp':
s = socket.socket()
# '*' means all interfaces to 0MQ, which is '' to socket.socket
s.bind(('' if self.ip == '*' else self.ip, 0))
self.port = s.getsockname()[1]
s.close()
elif addr[0] == 'ipc':
while os.path.exists(self.ip + '-' + self.port):
self.port = self.port + 1
else:
raise ValueError("Unrecognized zmq transport: %s" % addr[0])
self.addr = (self.ip, self.port)
self.daemon = True

def run(self):
self.socket = self.context.socket(zmq.REP)
self.socket.bind('tcp://%s:%i' % self.addr)
c = ':' if self.transport == 'tcp' else '-'
self.socket.bind('%s://%s' % (self.transport, self.ip) + c + str(self.port))
zmq.device(zmq.FORWARDER, self.socket, self.socket)

32 changes: 25 additions & 7 deletions IPython/zmq/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@
from IPython.utils.localinterfaces import LOCALHOST
from IPython.utils.path import filefind
from IPython.utils.py3compat import str_to_bytes
from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Integer, Bool,
DottedObjectName)
from IPython.utils.traitlets import (
Any, Instance, Dict, Unicode, Integer, Bool, CaselessStrEnum,
DottedObjectName,
)
from IPython.utils.importstring import import_item
# local imports
from IPython.zmq.entry_point import write_connection_file
Expand Down Expand Up @@ -109,6 +111,7 @@ def _parent_appname_changed(self, name, old, new):
self.config_file_specified = False

# connection info:
transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
ip = Unicode(LOCALHOST, config=True,
help="Set the IP or interface on which the kernel will listen.")
hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
Expand Down Expand Up @@ -154,11 +157,12 @@ def init_poller(self):
self.poller = ParentPollerUnix()

def _bind_socket(self, s, port):
iface = 'tcp://%s' % self.ip
if port <= 0:
iface = '%s://%s' % (self.transport, self.ip)
if port <= 0 and self.transport == 'tcp':
port = s.bind_to_random_port(iface)
else:
s.bind(iface + ':%i'%port)
c = ':' if self.transport == 'tcp' else '-'
s.bind(iface + c + str(port))
return port

def load_connection_file(self):
Expand All @@ -174,6 +178,7 @@ def load_connection_file(self):
with open(fname) as f:
s = f.read()
cfg = json.loads(s)
self.transport = cfg.get('transport', self.transport)
if self.ip == LOCALHOST and 'ip' in cfg:
# not overridden by config or cl_args
self.ip = cfg['ip']
Expand All @@ -191,7 +196,7 @@ def write_connection_file(self):
cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
else:
cf = self.connection_file
write_connection_file(cf, ip=self.ip, key=self.session.key,
write_connection_file(cf, ip=self.ip, key=self.session.key, transport=self.transport,
shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port,
iopub_port=self.iopub_port)

Expand All @@ -204,6 +209,19 @@ def cleanup_connection_file(self):
os.remove(cf)
except (IOError, OSError):
pass

self._cleanup_ipc_files()

def _cleanup_ipc_files(self):
"""cleanup ipc files if we wrote them"""
if self.transport != 'ipc':
return
for port in (self.shell_port, self.iopub_port, self.stdin_port, self.hb_port):
ipcfile = "%s-%i" % (self.ip, port)
try:
os.remove(ipcfile)
except (IOError, OSError):
pass

def init_connection_file(self):
if not self.connection_file:
Expand Down Expand Up @@ -238,7 +256,7 @@ def init_heartbeat(self):
# heartbeat doesn't share context, because it mustn't be blocked
# by the GIL, which is accessed by libzmq when freeing zero-copy messages
hb_ctx = zmq.Context()
self.heartbeat = Heartbeat(hb_ctx, (self.ip, self.hb_port))
self.heartbeat = Heartbeat(hb_ctx, (self.transport, self.ip, self.hb_port))
self.hb_port = self.heartbeat.port
self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
self.heartbeat.start()
Expand Down

0 comments on commit 25951fa

Please sign in to comment.