Skip to content

Commit

Permalink
tree: Python 3 support (#337)
Browse files Browse the repository at this point in the history
Feed the expat reader with bytes instead of strings. The expat reader
already had an explicit encoding defined (UTF-8), so the changes are minor.

Fix exception handling for pickle and base64 decoding, which is slightly
different in Python 3.

Fix tree tests for Python 3.

Change-Id: I790ab4fb1abc772249fd7cdc9c6467f96fdb1d01
  • Loading branch information
thiell committed Aug 8, 2017
1 parent ca3454c commit 05033db
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 43 deletions.
20 changes: 11 additions & 9 deletions lib/ClusterShell/Communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import cPickle

import base64
import binascii
import logging
import os
import xml.sax
Expand Down Expand Up @@ -113,10 +114,9 @@ def endElement(self, name):
self.msg_queue.appendleft(EndMessage())

def characters(self, content):
"""read content characters"""
"""read content characters (always decoded string)"""
if self._draft is not None:
content = content.decode(ENCODING)
self._draft.data_update(content)
self._draft.data_update(content.encode(ENCODING))

def msg_available(self):
"""return whether a message is available for delivery or not"""
Expand Down Expand Up @@ -220,7 +220,7 @@ def ev_read(self, worker):
"""channel has data to read"""
raw = worker.current_msg
try:
self._parser.feed(raw + '\n')
self._parser.feed(raw + b'\n')
except SAXParseException as ex:
self.logger.error("SAXParseException: %s: %s", ex.getMessage(), raw)
# Warning: do not send malformed raw message back
Expand All @@ -245,7 +245,7 @@ def send(self, msg):
"""write an outgoing message as its XML representation"""
#self.logger.debug('SENDING to worker %s: "%s"', id(self.worker),
# msg.xml())
self.worker.write(msg.xml() + '\n', sname=self.SNAME_WRITER)
self.worker.write(msg.xml() + b'\n', sname=self.SNAME_WRITER)

def start(self):
"""initialization logic"""
Expand Down Expand Up @@ -284,15 +284,17 @@ def data_encode(self, inst):
# and are ignored.
line_length = int(os.environ.get('CLUSTERSHELL_GW_B64_LINE_LENGTH',
DEFAULT_B64_LINE_LENGTH))
self.data = '\n'.join(encoded[pos:pos+line_length]
for pos in range(0, len(encoded), line_length))
self.data = b'\n'.join(encoded[pos:pos+line_length]
for pos in range(0, len(encoded), line_length))

def data_decode(self):
"""deserialize a previously encoded instance and return it"""
# NOTE: the name is confusing: data_decode() returns pickle-decoded bytes
# (an encoded string), not a (decoded) string.
# if self.data is None then an exception is raised here
try:
return cPickle.loads(base64.b64decode(self.data))
except (EOFError, TypeError):
except (EOFError, TypeError, cPickle.UnpicklingError, binascii.Error):
# raised by base64.b64decode() or cPickle.loads() if self.data is not valid
raise MessageProcessingError('Message %s has an invalid payload'
% self.ident)
Expand All @@ -301,7 +303,7 @@ def data_update(self, raw):
"""append data to the instance (used for deserialization)"""
if self.has_payload:
if self.data is None:
self.data = raw # first encoded packet
self.data = raw # first encoded packet
else:
self.data += raw
else:
Expand Down
6 changes: 4 additions & 2 deletions lib/ClusterShell/Propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,16 @@ def recv_ctl(self, msg):
metaworker = self.workers[msg.srcid]
if msg.type == StdOutMessage.ident:
nodeset = NodeSet(msg.nodes)
decoded = msg.data_decode() + '\n'
# msg.data_decode()'s name is a bit confusing: it returns
# pickle-decoded bytes (an encoded string), not a string.
decoded = msg.data_decode() + b'\n'
for line in decoded.splitlines():
for node in nodeset:
metaworker._on_remote_node_msgline(node, line, 'stdout',
self.gateway)
elif msg.type == StdErrMessage.ident:
nodeset = NodeSet(msg.nodes)
decoded = msg.data_decode() + '\n'
decoded = msg.data_decode() + b'\n'
for line in decoded.splitlines():
for node in nodeset:
metaworker._on_remote_node_msgline(node, line, 'stderr',
Expand Down
4 changes: 2 additions & 2 deletions tests/TreeCopyTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def tearDown(self):
def test_copy(self):
"""test file copy setup in tree mode (1 gateway)"""
self.test_ok = False
self.tfile = make_temp_file("dummy")
self.tfile = make_temp_file(b"dummy")
# add trailing '/' like clush so that WorkerTree knows it's a dir
task_self().copy(self.tfile.name,
join(dirname(self.tfile.name), ''),
Expand All @@ -77,7 +77,7 @@ def test_copy(self):
def test_rcopy(self):
"""test file rcopy setup in tree mode (1 gateway)"""
self.test_ok = False
self.tfile = make_temp_file("dummy-src")
self.tfile = make_temp_file(b"dummy-src")
self.tdir = make_temp_dir()
task_self().rcopy(self.tfile.name, self.tdir, "n60")
task_self().resume()
Expand Down
70 changes: 44 additions & 26 deletions tests/TreeGatewayTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,12 @@ def __init__(self):
self.task.resume()

def send(self, msg):
"""send msg to pseudo stdin"""
os.write(self.pipe_stdin[1], msg + '\n')
"""send msg (bytes) to pseudo stdin"""
os.write(self.pipe_stdin[1], msg + b'\n')

def send_str(self, msgstr):
"""send msg (string) to pseudo stdin"""
self.send(msgstr.encode())

def recv(self):
"""recv buf from pseudo stdout (blocking call)"""
Expand Down Expand Up @@ -101,11 +105,11 @@ def tearDown(self):
#
def channel_send_start(self):
"""send starting channel tag"""
self.gateway.send('<channel version="%s">' % __version__)
self.gateway.send_str('<channel version="%s">' % __version__)

def channel_send_stop(self):
"""send channel ending tag"""
self.gateway.send("</channel>")
self.gateway.send_str("</channel>")

def channel_send_cfg(self, gateway):
"""send configuration part of channel"""
Expand All @@ -127,6 +131,7 @@ def _recvxml(self):
xml_msg = self.gateway.recv()
if len(xml_msg) == 0:
return None
self.assertTrue(type(xml_msg) is bytes)
self.parser.feed(xml_msg)

return self.xml_reader.pop_msg()
Expand Down Expand Up @@ -194,10 +199,15 @@ def _check_channel_err(self, sendmsg, errback, openchan=True,
self.assertEqual(self.chan.setup, True)

# send the erroneous message and test gateway reply
self.gateway.send(sendmsg)
self.gateway.send_str(sendmsg)
msg = self.recvxml(ErrorMessage)
self.assertEqual(msg.type, 'ERR')
self.assertEqual(msg.reason, errback)
try:
if not errback.search(msg.reason):
self.assertFalse(msg.reason)
except AttributeError:
# not a regex
self.assertEqual(msg.reason, errback)

# gateway should terminate channel session
if openchan:
Expand Down Expand Up @@ -283,14 +293,14 @@ def test_channel_err_xml_bad_char(self):
def test_channel_err_missingattr(self):
"""test gateway channel message bad attributes"""
self._check_channel_err(
'<message msgid="24" type="RET"></message>',
'<message msgid="24" nodes="foo" retcode="4" type="RET"></message>',
'Invalid "message" attributes: missing key "srcid"')

def test_channel_err_unexpected(self):
"""test gateway channel unexpected message"""
self._check_channel_err(
'<message type="ACK" ack="2" msgid="2"></message>',
'unexpected message: Message ACK (ack: 2, msgid: 2, type: ACK)')
re.compile(r'unexpected message: Message ACK \(.*ack: 2.*\)'))

def test_channel_err_cfg_missing_gw(self):
"""test gateway channel message missing gateway nodename"""
Expand All @@ -310,12 +320,20 @@ def test_channel_err_unexpected_pl(self):
'<message msgid="14" type="ERR" reason="test">FOO</message>',
'Got unexpected payload for Message ERR', setupchan=True)

def test_channel_err_badenc_pl(self):
"""test gateway channel message badly encoded payload"""
def test_channel_err_badenc_b2a_pl(self):
"""test gateway channel message badly encoded payload (base64)"""
# Generate TypeError (py2) or binascii.Error (py3)
self._check_channel_err(
'<message msgid="14" type="CFG" gateway="n1">bar</message>',
'Message CFG has an invalid payload')

def test_channel_err_badenc_pickle_pl(self):
"""test gateway channel message badly encoded payload (pickle)"""
# Generate pickle error
self._check_channel_err(
'<message msgid="14" type="CFG" gateway="n1">barm</message>',
'Message CFG has an invalid payload')

def test_channel_basic_abort(self):
"""test gateway channel aborted while opened"""
self.channel_send_start()
Expand All @@ -327,7 +345,7 @@ def test_channel_basic_abort(self):

def _check_channel_ctl_shell(self, command, target, stderr, remote,
reply_msg_class, reply_pattern,
write_string=None, timeout=-1, replycnt=1,
write_buf=None, timeout=-1, replycnt=1,
reply_rc=0):
"""helper to check channel shell action"""
self.channel_send_start()
Expand Down Expand Up @@ -359,12 +377,12 @@ def _check_channel_ctl_shell(self, command, target, stderr, remote,

self.recvxml(ACKMessage)

if write_string:
if write_buf:
ctl = ControlMessage(id(workertree))
ctl.action = 'write'
ctl.target = NodeSet(target)
ctl_data = {
'buf': write_string,
'buf': write_buf,
}
# Send write message
ctl.data_encode(ctl_data)
Expand Down Expand Up @@ -404,49 +422,49 @@ def _check_channel_ctl_shell(self, command, target, stderr, remote,
def test_channel_ctl_shell_local1(self):
"""test gateway channel shell stdout (stderr=False remote=False)"""
self._check_channel_ctl_shell("echo ok", "n10", False, False,
StdOutMessage, "ok")
StdOutMessage, b"ok")

def test_channel_ctl_shell_local2(self):
"""test gateway channel shell stdout (stderr=True remote=False)"""
self._check_channel_ctl_shell("echo ok", "n10", True, False,
StdOutMessage, "ok")
StdOutMessage, b"ok")

def test_channel_ctl_shell_local3(self):
"""test gateway channel shell stderr (stderr=True remote=False)"""
self._check_channel_ctl_shell("echo ok >&2", "n10", True, False,
StdErrMessage, "ok")
StdErrMessage, b"ok")

def test_channel_ctl_shell_mlocal1(self):
"""test gateway channel shell multi (remote=False)"""
self._check_channel_ctl_shell("echo ok", "n[10-49]", True, False,
StdOutMessage, "ok", replycnt=40)
StdOutMessage, b"ok", replycnt=40)

def test_channel_ctl_shell_mlocal2(self):
"""test gateway channel shell multi stderr (remote=False)"""
self._check_channel_ctl_shell("echo ok 1>&2", "n[10-49]", True, False,
StdErrMessage, "ok", replycnt=40)
StdErrMessage, b"ok", replycnt=40)

def test_channel_ctl_shell_mlocal3(self):
"""test gateway channel shell multi placeholder (remote=False)"""
self._check_channel_ctl_shell('echo node %h rank %n', "n[10-29]", True,
False, StdOutMessage,
re.compile(r"node n\d+ rank \d+"),
re.compile(br"node n\d+ rank \d+"),
replycnt=20)

def test_channel_ctl_shell_remote1(self):
"""test gateway channel shell stdout (stderr=False remote=True)"""
self._check_channel_ctl_shell("echo ok", "n10", False, True,
StdOutMessage,
re.compile("(Could not resolve hostname|"
"Name or service not known)"),
re.compile(b"(Could not resolve hostname|"
b"Name or service not known)"),
reply_rc=255)

def test_channel_ctl_shell_remote2(self):
"""test gateway channel shell stdout (stderr=True remote=True)"""
self._check_channel_ctl_shell("echo ok", "n10", True, True,
StdErrMessage,
re.compile("(Could not resolve hostname|"
"Name or service not known)"),
re.compile(b"(Could not resolve hostname|"
b"Name or service not known)"),
reply_rc=255)

def test_channel_ctl_shell_timeo1(self):
Expand All @@ -457,14 +475,14 @@ def test_channel_ctl_shell_timeo1(self):
def test_channel_ctl_shell_wrloc1(self):
"""test gateway channel write (stderr=False remote=False)"""
self._check_channel_ctl_shell("cat", "n10", False, False,
StdOutMessage, "ok", write_string="ok\n")
StdOutMessage, b"ok", write_buf=b"ok\n")

def test_channel_ctl_shell_wrloc2(self):
"""test gateway channel write (stderr=True remote=False)"""
self._check_channel_ctl_shell("cat", "n10", True, False,
StdOutMessage, "ok", write_string="ok\n")
StdOutMessage, b"ok", write_buf=b"ok\n")

def test_channel_ctl_shell_mwrloc1(self):
"""test gateway channel write multi (remote=False)"""
self._check_channel_ctl_shell("cat", "n[10-49]", True, False,
StdOutMessage, "ok", write_string="ok\n")
StdOutMessage, b"ok", write_buf=b"ok\n")
13 changes: 9 additions & 4 deletions tests/TreeTaskTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import os
from textwrap import dedent
import unittest

from ClusterShell.Task import task_self
Expand All @@ -24,8 +25,10 @@ def tearDown(self):
def test_shell_auto_tree_dummy(self):
"""test task shell auto tree"""
# initialize a dummy topology.conf file
topofile = make_temp_file(
'[Main]\n%s: dummy-gw\ndummy-gw: dummy-node\n' % HOSTNAME)
topofile = make_temp_file(dedent("""
[Main]
%s: dummy-gw
dummy-gw: dummy-node"""% HOSTNAME).encode())
task = task_self()
task.set_default("auto_tree", True)
task.TOPOLOGY_CONFIGS = [topofile.name]
Expand All @@ -51,8 +54,10 @@ def test_shell_auto_tree_noconf(self):
def test_shell_auto_tree_error(self):
"""test task shell auto tree [TopologyError]"""
# initialize an erroneous topology.conf file
topofile = make_temp_file(
'[Main]\n%s: dummy-gw\ndummy-gw: dummy-gw\n' % HOSTNAME)
topofile = make_temp_file(dedent("""
[Main]
%s: dummy-gw
dummy-gw: dummy-gw"""% HOSTNAME).encode())
task = task_self()
task.set_default("auto_tree", True)
task.TOPOLOGY_CONFIGS = [topofile.name]
Expand Down

0 comments on commit 05033db

Please sign in to comment.