Skip to content

Commit

Permalink
Merge pull request #63 from anton-ryzhov/graceful-connection-shutdown
Browse files Browse the repository at this point in the history
Graceful connection shutdown in CLI
  • Loading branch information
ionelmc committed Apr 8, 2021
2 parents 73be8ff + facb2c1 commit 2ee1f36
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -4,6 +4,7 @@ Changelog

dev
* Simplify connection closing code
* Graceful connection shutdown in ``manhole-cli``

1.7.0 (2021-03-22)
------------------
Expand Down
2 changes: 2 additions & 0 deletions src/manhole/__init__.py
Expand Up @@ -363,6 +363,8 @@ def handle_repl(locals):
namespace.update(locals)
try:
ManholeConsole(namespace).interact()
except SystemExit:
pass
finally:
for attribute in ['last_type', 'last_value', 'last_traceback']:
try:
Expand Down
62 changes: 22 additions & 40 deletions src/manhole/cli.py
Expand Up @@ -6,7 +6,6 @@
import os
import re
import readline
import select
import signal
import socket
import sys
Expand Down Expand Up @@ -69,41 +68,26 @@ def parse_signal(value):


class ConnectionHandler(threading.Thread):
def __init__(self, timeout, sock, read_fd=None, wait_the_end=True):
def __init__(self, sock, is_closing):
super(ConnectionHandler, self).__init__()
self.sock = sock
self.read_fd = read_fd
self.conn_fd = sock.fileno()
self.timeout = timeout
self.should_run = True
self._poller = select.poll()
self.wait_the_end = wait_the_end
self.is_closing = is_closing

def run(self):
if self.read_fd is not None:
self._poller.register(self.read_fd, select.POLLIN)
self._poller.register(self.conn_fd, select.POLLIN)

while self.should_run:
self.poll()
if self.wait_the_end:
t = time.time()
while time.time() - t < self.timeout:
self.poll()

def poll(self):
milliseconds = self.timeout * 1000
for fd, _ in self._poller.poll(milliseconds):
if fd == self.conn_fd:
data = self.sock.recv(1024*1024)
while True:
try:
data = self.sock.recv(1024**2)
if not data:
break
sys.stdout.write(data.decode('utf8'))
sys.stdout.flush()
readline.redisplay()
elif fd == self.read_fd:
data = os.read(self.read_fd, 1024)
self.sock.sendall(data)
else:
raise RuntimeError("Unknown FD %s" % fd)
except socket.timeout:
pass

if not self.is_closing.is_set():
# Break waiting for input()
os.kill(os.getpid(), signal.SIGINT)


def main():
Expand Down Expand Up @@ -138,21 +122,19 @@ def main():
print("Failed to connect to %r: Timeout" % uds_path, file=sys.stderr)
sys.exit(5)

read_fd, write_fd = os.pipe()

thread = ConnectionHandler(args.timeout, sock, read_fd, not sys.stdin.isatty())
is_closing = threading.Event()
thread = ConnectionHandler(sock, is_closing)
thread.start()

try:
while thread.is_alive():
try:
data = input()
except EOFError:
break
os.write(write_fd, data.encode('utf8'))
os.write(write_fd, b'\n')
except KeyboardInterrupt:
data = input()
data += '\n'
sock.sendall(data.encode('utf8'))
except (EOFError, KeyboardInterrupt):
pass
finally:
thread.should_run = False
is_closing.set()
sock.shutdown(socket.SHUT_WR)
thread.join()
sock.close()

0 comments on commit 2ee1f36

Please sign in to comment.