Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CLI: fix line buffering with Python 3 (#528) #537

Merged
merged 1 commit into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions lib/ClusterShell/CLI/Clubak.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from ClusterShell.NodeSet import set_std_group_resolver_config

from ClusterShell.CLI.Display import Display, THREE_CHOICES
from ClusterShell.CLI.Display import sys_stdin, sys_stdout
from ClusterShell.CLI.Display import sys_stdin
from ClusterShell.CLI.Error import GENERIC_ERRORS, handle_generic_error
from ClusterShell.CLI.OptionParser import OptionParser
from ClusterShell.CLI.Utils import nodeset_cmpkey
Expand All @@ -55,40 +55,38 @@ def display_tree(tree, disp, out):
reldepth = reldepths[depth] = reldepth + offset
nodeset = NodeSet.fromlist(keys)
if line_mode:
out.write(str(nodeset).encode() + b':\n')
out.write(str(nodeset) + ':\n')
else:
out.write(disp.format_header(nodeset, reldepth))
out.write(b' ' * reldepth + msgline + b'\n')
out.write(' ' * reldepth + bytes(msgline).decode(errors='replace')
+ '\n')
togh = nchildren != 1

def display(tree, disp, gather, trace_mode, enable_nodeset_key):
"""nicely display MsgTree instance `tree' content according to
`disp' Display object and `gather' boolean flag"""
out = sys_stdout()
try:
if trace_mode:
display_tree(tree, disp, out)
if trace_mode:
display_tree(tree, disp, sys.stdout)
sys.stdout.flush()
return
if gather:
if enable_nodeset_key:
# lambda to create a NodeSet from keys returned by walk()
ns_getter = lambda x: NodeSet.fromlist(x[1])
for nodeset in sorted((ns_getter(item) for item in tree.walk()),
key=nodeset_cmpkey):
disp.print_gather(nodeset, tree[nodeset[0]])
else:
if gather:
if enable_nodeset_key:
# lambda to create a NodeSet from keys returned by walk()
ns_getter = lambda x: NodeSet.fromlist(x[1])
for nodeset in sorted((ns_getter(item) for item in tree.walk()),
key=nodeset_cmpkey):
disp.print_gather(nodeset, tree[nodeset[0]])
else:
for msg, key in tree.walk():
disp.print_gather_keys(key, msg)
else:
if enable_nodeset_key:
# nodes are automagically sorted by NodeSet
for node in NodeSet.fromlist(tree.keys()).nsiter():
disp.print_gather(node, tree[str(node)])
else:
for key in tree.keys():
disp.print_gather_keys([ key ], tree[key])
finally:
out.flush()
for msg, key in tree.walk():
disp.print_gather_keys(key, msg)
else:
if enable_nodeset_key:
# nodes are automagically sorted by NodeSet
for node in NodeSet.fromlist(tree.keys()).nsiter():
disp.print_gather(node, tree[str(node)])
else:
for key in tree.keys():
disp.print_gather_keys([ key ], tree[key])

def clubak():
"""script subroutine"""
Expand Down Expand Up @@ -128,9 +126,11 @@ def clubak():
try:
linestripped = line.rstrip(b'\r\n')
if options.verbose or options.debug:
sys_stdout().write(b'INPUT ' + linestripped + b'\n')
sys.stdout.write('INPUT ' +
linestripped.decode(errors='replace') + '\n')
key, content = linestripped.split(separator, 1)
key = key.strip().decode() # NodeSet requires encoded string
# NodeSet requires encoded string
key = key.strip().decode(errors='replace')
if not key:
raise ValueError("no node found")
if enable_nodeset_key is False: # interpret-keys=never?
Expand All @@ -150,7 +150,8 @@ def clubak():
for node in keyset:
tree.add(node, content)
except ValueError as ex:
raise ValueError('%s: "%s"' % (ex, linestripped.decode()))
raise ValueError('%s: "%s"' %
(ex, linestripped.decode(errors='replace')))

if fast_mode:
# Messages per node have been aggregated, now add to tree one
Expand Down
52 changes: 31 additions & 21 deletions lib/ClusterShell/CLI/Display.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#
# Copyright (C) 2010-2015 CEA/DAM
# Copyright (C) 2023 Stephane Thiell <sthiell@stanford.edu>
#
# This file is part of ClusterShell.
#
Expand Down Expand Up @@ -42,17 +43,10 @@
else:
STRING_ENCODING = sys.getdefaultencoding()


# Python 3 compat: wrappers for standard streams
# Python 3 compat: wrapper for stdin
def sys_stdin():
return getattr(sys.stdin, 'buffer', sys.stdin)

def sys_stdout():
return getattr(sys.stdout, 'buffer', sys.stdout)

def sys_stderr():
return getattr(sys.stderr, 'buffer', sys.stderr)


class Display(object):
"""
Expand Down Expand Up @@ -106,7 +100,7 @@ def __init__(self, options, config=None, color=None):
# NO_COLOR takes precedence over CLI_COLORS. --color option always
# takes precedence over any environment variable.

if options.whencolor is None:
if options.whencolor is None and color is not False:
if (config is None) or (config.color == '' or config.color == 'auto'):
if 'NO_COLOR' not in os.environ:
color = self._has_cli_color()
Expand All @@ -122,8 +116,19 @@ def __init__(self, options, config=None, color=None):
color = True

self._color = color
self.out = sys_stdout()
self.err = sys_stderr()
# GH#528 enable line buffering
self.out = sys.stdout
try :
if not self.out.line_buffering:
self.out.reconfigure(line_buffering=True)
except AttributeError: # < py3.7
pass
self.err = sys.stderr
try :
if not self.err.line_buffering:
self.err.reconfigure(line_buffering=True)
except AttributeError: # < py3.7
pass

if self._color:
self.color_stdout_fmt = self.COLOR_STDOUT_FMT
Expand Down Expand Up @@ -198,7 +203,7 @@ def _format_nodeset(self, nodeset):
def format_header(self, nodeset, indent=0):
"""Format nodeset-based header."""
if not self.label:
return b""
return ""
indstr = " " * indent
nodecntstr = ""
if self.verbosity >= VERB_STD and self.node_count and len(nodeset) > 1:
Expand All @@ -207,23 +212,25 @@ def format_header(self, nodeset, indent=0):
(indstr, self.SEP,
indstr, self._format_nodeset(nodeset), nodecntstr,
indstr, self.SEP))
return hdr.encode(STRING_ENCODING) + b'\n'
return hdr + '\n'

def print_line(self, nodeset, line):
"""Display a line with optional label."""
linestr = line.decode(STRING_ENCODING, errors='replace') + '\n'
if self.label:
prefix = self.color_stdout_fmt % ("%s: " % nodeset)
self.out.write(prefix.encode(STRING_ENCODING) + line + b'\n')
self.out.write(prefix + linestr)
else:
self.out.write(line + b'\n')
self.out.write(linestr)

def print_line_error(self, nodeset, line):
"""Display an error line with optional label."""
linestr = line.decode(STRING_ENCODING, errors='replace') + '\n'
if self.label:
prefix = self.color_stderr_fmt % ("%s: " % nodeset)
self.err.write(prefix.encode(STRING_ENCODING) + line + b'\n')
self.err.write(prefix + linestr)
else:
self.err.write(line + b'\n')
self.err.write(linestr)

def print_gather(self, nodeset, obj):
"""Generic method for displaying nodeset/content according to current
Expand All @@ -242,7 +249,8 @@ def print_gather_keys(self, keys, obj):

def _print_content(self, nodeset, content):
"""Display a dshbak-like header block and content."""
self.out.write(self.format_header(nodeset) + bytes(content) + b'\n')
s = bytes(content).decode(STRING_ENCODING, errors='replace')
self.out.write(self.format_header(nodeset) + s + '\n')

def _print_diff(self, nodeset, content):
"""Display unified diff between remote gathered outputs."""
Expand Down Expand Up @@ -275,7 +283,7 @@ def _print_diff(self, nodeset, content):
else:
output += line
output += '\n'
self.out.write(output.encode(STRING_ENCODING))
self.out.write(output)

def _print_lines(self, nodeset, msg):
"""Display a MsgTree buffer by line with prefixed header."""
Expand All @@ -284,10 +292,12 @@ def _print_lines(self, nodeset, msg):
header = self.color_stdout_fmt % \
("%s: " % self._format_nodeset(nodeset))
for line in msg:
out.write(header.encode(STRING_ENCODING) + line + b'\n')
out.write(header + line.decode(STRING_ENCODING,
errors='replace') + '\n')
else:
for line in msg:
out.write(line + b'\n')
out.write(line.decode(STRING_ENCODING,
errors='replace') + '\n')

def vprint(self, level, message):
"""Utility method to print a message if verbose level is high
Expand Down
7 changes: 7 additions & 0 deletions tests/CLIClubakTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ def test_001_verbosity(self):
self._clubak_t(["-v", "-b"], b"foo: bar\n", _outfmt_verb('foo'), 0)
# no node count with -q
self._clubak_t(["-q", "-b"], b"foo[1-5]: bar\n", _outfmt('foo[1-5]'), 0)
# non-printable characters replaced by the replacement character
self._clubak_t(["-L"], b"foo:\xffbar\n", "foo: \ufffdbar\n".encode(), 0)
self._clubak_t(["-d", "-L"], b"foo:\xf8bar\n",
'INPUT foo:\ufffdbar\nfoo: \ufffdbar\n'.encode(), 0,
b'line_mode=True gather=False tree_depth=1\n')

def test_002_b(self):
"""test clubak (gather -b)"""
Expand Down Expand Up @@ -118,6 +123,8 @@ def test_006_tree(self):
"""test clubak (tree mode --tree)"""
self._clubak_t(["--tree"], b"foo: bar\n", _outfmt("foo"))
self._clubak_t(["--tree", "-L"], b"foo: bar\n", b"foo:\n bar\n")
self._clubak_t(["--tree", "-L"], b"foo: \xf8bar\n",
"foo:\n \ufffdbar\n".encode())
stdin_buf = dedent("""foo1:bar
foo2:bar
foo1:moo
Expand Down
77 changes: 71 additions & 6 deletions tests/CLIDisplayTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
import unittest
import os
from io import BytesIO
from io import StringIO

from ClusterShell.CLI.Display import Display, THREE_CHOICES, VERB_STD
from ClusterShell.CLI.OptionParser import OptionParser
Expand Down Expand Up @@ -61,8 +61,8 @@ def testDisplay(self):
options.label = label
disp = Display(options)
# inhibit output
disp.out = BytesIO()
disp.err = BytesIO()
disp.out = StringIO()
disp.err = StringIO()
# test print_* methods...
disp.print_line(ns, b"foo bar")
disp.print_line_error(ns, b"foo bar")
Expand Down Expand Up @@ -103,16 +103,16 @@ def testDisplayRegroup(self):

disp = Display(options, color=False)
self.assertEqual(disp.regroup, True)
disp.out = BytesIO()
disp.err = BytesIO()
disp.out = StringIO()
disp.err = StringIO()
self.assertEqual(disp.line_mode, False)

ns = NodeSet("hostfoo")

# nodeset.regroup() is performed by print_gather()
disp.print_gather(ns, b"message0\nmessage1\n")
self.assertEqual(disp.out.getvalue(),
b"---------------\n@all\n---------------\nmessage0\nmessage1\n\n")
"---------------\n@all\n---------------\nmessage0\nmessage1\n\n")
finally:
set_std_group_resolver(None)

Expand All @@ -131,3 +131,68 @@ def testDisplayClubak(self):
self.assertEqual(disp.maxrc, False)
self.assertEqual(disp.node_count, True)
self.assertEqual(disp.verbosity, VERB_STD)

def testDisplayDecodingErrors(self):
"""test CLI.Display (decoding errors)"""
parser = OptionParser("dummy")
parser.install_display_options()
options, _ = parser.parse_args([])
disp = Display(options, color=False)
disp.out = StringIO()
disp.err = StringIO()
self.assertEqual(bool(disp.gather), False)
self.assertEqual(disp.line_mode, False)
ns = NodeSet("node")
disp.print_line(ns, b"message0\n\xf8message1\n")
self.assertEqual(disp.out.getvalue(),
"node: message0\n\ufffdmessage1\n\n")
disp.print_line_error(ns, b"message0\n\xf8message1\n")
self.assertEqual(disp.err.getvalue(),
"node: message0\n\ufffdmessage1\n\n")

def testDisplayDecodingErrorsGather(self):
"""test CLI.Display (decoding errors, gather)"""
parser = OptionParser("dummy")
parser.install_display_options(dshbak_compat=True)
options, _ = parser.parse_args(["-b"])
disp = Display(options, color=False)
disp.out = StringIO()
disp.err = StringIO()
self.assertEqual(bool(disp.gather), True)
self.assertEqual(disp.line_mode, False)
ns = NodeSet("node")
disp._print_buffer(ns, b"message0\n\xf8message1\n")
self.assertEqual(disp.out.getvalue(),
"---------------\nnode\n---------------\nmessage0\n\ufffdmessage1\n\n")

def testDisplayDecodingErrorsLineMode(self):
"""test CLI.Display (decoding errors, line mode)"""
parser = OptionParser("dummy")
parser.install_display_options(dshbak_compat=True)
options, _ = parser.parse_args(["-b", "-L"])
disp = Display(options, color=False)
disp.out = StringIO()
disp.err = StringIO()
self.assertEqual(bool(disp.gather), True)
self.assertEqual(disp.label, True)
self.assertEqual(disp.line_mode, True)
ns = NodeSet("node")
disp.print_gather(ns, [b"message0\n", b"\xf8message1\n"])
self.assertEqual(disp.out.getvalue(),
"node: message0\n\nnode: \ufffdmessage1\n\n")

def testDisplayDecodingErrorsLineModeNoLabel(self):
"""test CLI.Display (decoding errors, line mode, no label)"""
parser = OptionParser("dummy")
parser.install_display_options(dshbak_compat=True)
options, _ = parser.parse_args(["-b", "-L", "-N"])
disp = Display(options, color=False)
disp.out = StringIO()
disp.err = StringIO()
self.assertEqual(bool(disp.gather), True)
self.assertEqual(disp.label, False)
self.assertEqual(disp.line_mode, True)
ns = NodeSet("node")
disp.print_gather(ns, [b"message0\n", b"\xf8message1\n"])
self.assertEqual(disp.out.getvalue(),
"message0\n\n\ufffdmessage1\n\n")
14 changes: 6 additions & 8 deletions tests/TLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,11 @@ class TBytesIO(BytesIO):

def __init__(self, initial_bytes=None):
if initial_bytes and type(initial_bytes) is not bytes:
initial_bytes = initial_bytes.encode('ascii')
initial_bytes = initial_bytes.encode()
BytesIO.__init__(self, initial_bytes)

def write(self, b):
if type(b) is bytes:
BytesIO.write(self, b)
else:
BytesIO.write(self, b.encode('ascii'))
def write(self, s):
BytesIO.write(self, s.encode())

def isatty(self):
return False
Expand Down Expand Up @@ -104,8 +101,9 @@ def CLI_main(test, main, args, stdin, expected_stdout, expected_rc=0,
# should be read in text mode for some tests (eg. Nodeset).
sys.stdin = StringIO(stdin)

# Output: ClusterShell sends bytes to sys_stdout()/sys_stderr() and when
# print() is used, TBytesIO does a conversion to ascii.
# Output: ClusterShell writes to stdout/stderr using strings, but the tests
# expect bytes. TBytesIO is a wrapper that does the conversion until we
# migrate all tests to string.
sys.stdout = out = TBytesIO()
sys.stderr = err = TBytesIO()
sys.argv = args
Expand Down
Loading