Skip to content

Commit

Permalink
Support unconditionally logging stderr/stdout, no matter the spec.
Browse files Browse the repository at this point in the history
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
  • Loading branch information
ezyang committed Apr 10, 2019
1 parent c8fe2b1 commit 35335db
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[flake8]
select = B,C,E,F,P,T4,W,B9
max-line-length = 80
max-line-length = 200
### DEFAULT IGNORES FOR 4-space INDENTED PROJECTS ###
# E127, E128 are hard to silence in certain nested formatting situations.
# E203 doesn't work for slicing
Expand Down
31 changes: 31 additions & 0 deletions emitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python3

# Helper binary used by test_shell.py to print interleaved sequences
# of strings to stderr/stdout.

import sys
from typing import Sequence, TypeVar, Tuple, Iterator
import itertools

T = TypeVar('T')


def grouper(n: int, iterable: Sequence[T]) -> Iterator[Tuple[T, ...]]:
"grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
args = [iter(iterable)] * n
return itertools.zip_longest(*args)


for mode, payload in grouper(2, sys.argv[1:]):
if mode == 'e':
print(payload, end='', file=sys.stderr)
sys.stderr.flush()
elif mode == 'o':
print(payload, end='', file=sys.stdout)
sys.stdout.flush()
elif mode == 'r':
# Big enough payload to exceed default chunk limit
print("." * (4096 * 128), file=sys.stdout)
sys.stdout.flush()
else:
raise RuntimeError('Unrecognized mode')
115 changes: 89 additions & 26 deletions ghstack/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import subprocess
import os
import logging
from typing import Dict, Sequence, Optional, TypeVar, Union, Any, overload, IO
from typing import Dict, Sequence, Optional, TypeVar, Union, Any, overload, IO, Tuple
import asyncio
import sys


# Shell commands generally return str, but with exitcode=True
Expand All @@ -23,6 +25,8 @@ def log_command(args: Sequence[str]) -> None:
*args: the list of command line arguments you want to run
env: the dictionary of environment variable settings for the command
"""
# TODO: Irritatingly, this doesn't insert quotes for shell
# metacharacters like exclamation marks or parentheses.
cmd = subprocess.list2cmdline(args).replace("\n", "\\n")
logging.info("$ " + cmd)

Expand Down Expand Up @@ -81,7 +85,7 @@ def __init__(self,
self.testing = testing
self.testing_time = 1112911993

def sh(self, *args: str,
def sh(self, *args: str, # noqa: C901
env: Optional[Dict[str, str]] = None,
stderr: _HANDLE = None,
input: Optional[str] = None,
Expand Down Expand Up @@ -117,34 +121,93 @@ def sh(self, *args: str,
log_command(args)
if env is not None:
env = merge_dicts(dict(os.environ), env)
p = subprocess.Popen(
args,
stdout=stdout,
stdin=stdin,
stderr=stderr,
cwd=self.cwd,
env=env
)
input_bytes = None
if input is not None:
input_bytes = input.encode('utf-8')
out, err = p.communicate(input_bytes)
if err is not None:
# NB: Not debug; we always want to show this to user.
logging.info(err)

# The things we do for logging...
#
# - I didn't make a PTY, so programs are going to give
# output assuming there isn't a terminal at the other
# end. This is less nice for direct terminal use, but
# it's better for logging (since we get to dispense
# with the control codes).
#
# - We assume line buffering. This is kind of silly but
# what are you going to do.

async def process_stream(proc_stream: asyncio.StreamReader, setting: _HANDLE,
default_stream: IO[Any]) -> bytes:
output = []
while True:
try:
line = await proc_stream.readuntil()
except asyncio.LimitOverrunError as e:
line = await proc_stream.readexactly(e.consumed)
except asyncio.IncompleteReadError as e:
line = e.partial
if not line:
break
output.append(line)
if setting == subprocess.PIPE:
pass
elif setting == subprocess.STDOUT:
sys.stdout.buffer.write(line)
elif isinstance(setting, int):
os.write(setting, line)
elif setting is None:
default_stream.write(line)
else:
setting.write(line)
return b''.join(output)

async def feed_input(stdin_writer: Optional[asyncio.StreamWriter]) -> None:
if stdin_writer is None:
return
if not input:
return
stdin_writer.write(input.encode('utf-8'))
await stdin_writer.drain()
stdin_writer.close()

async def run() -> Tuple[int, bytes, bytes]:
proc = await asyncio.create_subprocess_exec(
*args,
stdin=stdin,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.cwd,
env=env,
)
assert proc.stdout is not None
assert proc.stderr is not None
_, out, err, _ = await asyncio.gather(
feed_input(proc.stdin),
process_stream(proc.stdout, stdout, sys.stdout.buffer),
process_stream(proc.stderr, stderr, sys.stdout.buffer),
proc.wait()
)
return (proc.returncode, out, err)

loop = asyncio.get_event_loop()
returncode, out, err = loop.run_until_complete(run())

# NB: Not debug; we always want to show this to user.
if err:
logging.debug("# stderr:\n" + err.decode(errors="backslashreplace"))
if out:
logging.debug(
("# stdout:\n" if err else "")
+ out.decode(errors="backslashreplace").replace('\0', '\\0'))

if exitcode:
logging.debug("Exit code: {}".format(p.returncode))
return p.returncode == 0
if p.returncode != 0:
logging.debug("Exit code: {}".format(returncode))
return returncode == 0
if returncode != 0:
raise RuntimeError(
"{} failed with exit code {}"
.format(' '.join(args), p.returncode)
.format(' '.join(args), returncode)
)
if out is not None:
r = out.decode()
assert isinstance(r, str)
logging.debug(r.replace('\0', '\\0'))
return r

if stdout == subprocess.PIPE:
return out.decode() # do a strict decode for actual return
else:
return None

Expand Down
1 change: 1 addition & 0 deletions run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ set -e
flake8-3 ghstack
mypy --strict --config=detailed-mypy.ini ghstack test_ghstack.py
python3 test_expecttest.py
python3 test_shell.py
python3 test_ghstack.py
7 changes: 3 additions & 4 deletions test_ghstack.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ class TestGh(expecttest.TestCase):
sh: ghstack.shell.Shell

def setUp(self) -> None:
tmp_dir = tempfile.mkdtemp()

# Set up a "parent" repository with an empty initial commit that we'll operate on
upstream_dir = tempfile.mkdtemp()
if GH_KEEP_TMP:
Expand Down Expand Up @@ -903,7 +901,7 @@ def test_no_clobber(self) -> None:
print("### test_no_clobber")
self.sh.git("commit", "--allow-empty", "-m", "Commit 1\n\nOriginal message")
self.sh.test_tick()
stack = self.gh('Initial 1')
self.gh('Initial 1')
self.sh.test_tick()
self.substituteRev("HEAD", "rCOM1")
self.substituteRev("origin/gh/ezyang/1/head", "rMRG1")
Expand Down Expand Up @@ -979,7 +977,7 @@ def test_update_fields(self) -> None:
print("### test_update_fields")
self.sh.git("commit", "--allow-empty", "-m", "Commit 1\n\nOriginal message")
self.sh.test_tick()
stack = self.gh('Initial 1')
self.gh('Initial 1')
self.sh.test_tick()
self.substituteRev("HEAD", "rCOM1")
self.substituteRev("origin/gh/ezyang/1/head", "rMRG1")
Expand Down Expand Up @@ -1244,4 +1242,5 @@ def test_remove_bottom_commit(self) -> None:


if __name__ == '__main__':
logging.basicConfig(level=logging.DEBUG, format='%(message)s')
unittest.main()
135 changes: 135 additions & 0 deletions test_shell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3

import unittest
import logging
from typing import List, Any
from dataclasses import dataclass
import sys

import ghstack.expecttest as expecttest
import ghstack.shell


@dataclass
class ConsoleMsg:
pass


@dataclass
class out(ConsoleMsg):
msg: str


@dataclass
class err(ConsoleMsg):
msg: str


@dataclass
class big_dump(ConsoleMsg):
pass


class TestShell(expecttest.TestCase):
def setUp(self) -> None:
self.sh = ghstack.shell.Shell()

def emit(self, *payload: ConsoleMsg, **kwargs: Any) -> ghstack.shell._SHELL_RET:
args: List[str] = [sys.executable, 'emitter.py']
for p in payload:
if isinstance(p, out):
args.extend(('o', p.msg))
elif isinstance(p, err):
args.extend(('e', p.msg))
elif isinstance(p, big_dump):
args.extend(('r', '-'))
return self.sh.sh(*args, **kwargs)

def flog(self, cm: 'unittest._AssertLogsContext') -> str:
def redact(s: str) -> str:
s = s.replace(sys.executable, 'python')
return s
return '\n'.join(redact(r.getMessage()) for r in cm.records)

def test_stdout(self) -> None:
with self.assertLogs(level=logging.DEBUG) as cm:
self.emit(
out("arf\n")
)
self.assertExpected(self.flog(cm), '''\
$ python emitter.py o arf\\n
arf
''')

def test_stderr(self) -> None:
with self.assertLogs(level=logging.DEBUG) as cm:
self.emit(
err("arf\n")
)
self.assertExpected(self.flog(cm), '''\
$ python emitter.py e arf\\n
# stderr:
arf
''')

def test_stdout_passthru(self) -> None:
with self.assertLogs(level=logging.DEBUG) as cm:
self.emit(
out("arf\n"),
stdout=None
)
self.assertExpected(self.flog(cm), '''\
$ python emitter.py o arf\\n
arf
''')

def test_stdout_with_stderr_prefix(self) -> None:
# What most commands should look like
with self.assertLogs(level=logging.DEBUG) as cm:
self.emit(
err("Step 1...\n"),
err("Step 2...\n"),
err("Step 3...\n"),
out("out\n"),
stdout=None
)
self.assertExpected(self.flog(cm), '''\
$ python emitter.py e "Step 1...\\n" e "Step 2...\\n" e "Step 3...\\n" o out\\n
# stderr:
Step 1...
Step 2...
Step 3...
# stdout:
out
''')

def test_interleaved_stdout_stderr_passthru(self) -> None:
# NB: stdout is flushed in each of these cases
with self.assertLogs(level=logging.DEBUG) as cm:
self.emit(
out("A\n"),
err("B\n"),
out("C\n"),
err("D\n"),
stdout=None
)
self.assertExpected(self.flog(cm), '''\
$ python emitter.py o A\\n e B\\n o C\\n e D\\n
# stderr:
B
D
# stdout:
A
C
''')

def test_deadlock(self) -> None:
self.emit(
big_dump()
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 35335db

Please sign in to comment.