Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
executable file 187 lines (156 sloc) 7.08 KB
#!/usr/bin/env python3
import itertools as it, operator as op, functools as ft
from collections import OrderedDict
import os, sys, logging, asyncio, socket, signal, contextlib
class LogMessage(object):
    """Log message that is formatted lazily with ``str.format`` placeholders.

    Formatting only happens when ``str()`` is called on the instance, i.e.
    when a log record is actually being rendered by some handler.

    Args:
        fmt: message or "{}"-style format string.
        a: positional values for the format string.
        k: keyword values for the format string.
    """

    def __init__(self, fmt, a, k):
        self.fmt, self.a, self.k = fmt, a, k

    def __str__(self):
        # Without any values the message is returned verbatim,
        # so that literal braces in it are not parsed as placeholders.
        return self.fmt.format(*self.a, **self.k) if self.a or self.k else self.fmt


class LogStyleAdapter(logging.LoggerAdapter):
    """Logger adapter that renders messages via ``str.format`` instead of "%"."""

    def __init__(self, logger, extra=None):
        super(LogStyleAdapter, self).__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        """Log `msg` at `level`, formatting it lazily with `args` and `kws`.

        The "exc_info", "stack_info" and "extra" keywords keep their usual
        logging meaning and are handed to the logger; every other keyword
        is a named value for the format string.
        """
        if not self.isEnabledFor(level): return
        msg, kws = self.process(msg, kws)
        # process() adds "extra" to the keywords - it has to be separated
        # from the format values, along with the other logging options,
        # otherwise LogMessage always sees a non-empty set of values.
        log_kws = {k: kws.pop(k) for k in ('exc_info', 'stack_info', 'extra') if k in kws}
        # Options must be passed by keyword: the fourth positional
        # parameter of Logger._log() is exc_info, not an options dict.
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    """Return a LogStyleAdapter wrapped around the standard logger `name`."""
    return LogStyleAdapter(logging.getLogger(name))
async def conn_discard(conn):
    """Cancel a pending connection task/future and wait until it is finished.

    Cancellation and socket errors raised by the awaited task are ignored.
    If the task had already produced a (reader, writer) pair by the time
    it gets discarded, that connection is aborted, so that an established
    socket is not left open.

    Args:
        conn: awaitable with a cancel() method; in this file these are
            tasks wrapping asyncio.open_connection().
    """
    conn.cancel()
    with contextlib.suppress(asyncio.CancelledError, socket.error):
        result = await conn
        # Only reached when the task completed before cancel() had any effect.
        if isinstance(result, tuple) and len(result) == 2:
            transport = getattr(result[1], 'transport', None)
            if transport is not None: transport.abort()
async def run_tping(loop, *args, **kws):
    """Run the connection-probing task until it finishes or a signal stops it.

    Starts _run_tping() as a task on `loop` (passing `args`/`kws` through),
    installs SIGINT/SIGTERM handlers that cancel it, and discards whatever
    connection tasks are still tracked once it is over.

    Returns:
        Exit code: 0 if probing finished by itself or was stopped
        by SIGTERM, 1 if it was interrupted by SIGINT.
    """
    conns = OrderedDict()
    exit_code = 0
    pinger = loop.create_task(_run_tping(loop, conns, *args, **kws))

    def sig_handler(code):
        nonlocal exit_code
        exit_code = code
        pinger.cancel()

    for signum, sig_code in (signal.SIGINT, 1), (signal.SIGTERM, 0):
        loop.add_signal_handler(signum, ft.partial(sig_handler, sig_code))

    # Cancellation from a signal handler is the expected way to stop.
    try:
        await pinger
    except asyncio.CancelledError:
        pass

    for task in conns.values():
        await conn_discard(task)
    return exit_code
async def _run_tping(loop, conns, sock_addr, sock_af, delay_retry, timeout):
ts_retry = loop.time() + delay_retry
while True:
ts = loop.time()
conns[ts] = t = loop.create_task(asyncio.open_connection(
*sock_addr[:2], family=sock_af, proto=socket.IPPROTO_TCP ))
t.ts_conn = ts
ts_abort = ts - timeout
for ts_conn in list(conns):
if ts_conn <= ts_abort:
await conn_discard(conns[ts_conn])
del conns[ts_conn]
else: break
while ts <= ts_retry:
delay, done = max(0, ts_retry - ts), None
if conns:
done, pending = await asyncio.wait( conns.values(),
timeout=delay, return_when=asyncio.FIRST_COMPLETED )
else: await asyncio.sleep(delay)
if done:
reader = writer = None
for res in done:
with contextlib.suppress(socket.error):
reader, writer = await res
writer.transport.abort()
del conns[res.ts_conn]
if reader or writer: break
done = None
ts = loop.time()
if done: break
ts_retry += delay_retry
def str_part(s, sep, default=None):
    """Split `s` in two on a delimiter, with `default` for the missing part.

    `sep` is the delimiter with "<" or ">" added to one side of it:
    "<" - split on the first delimiter, missing part is the left one;
    otherwise - split on the last delimiter, missing part is the right one.

    Examples: str_part("user@host", "<@", "root"), str_part("host:port", ":>")

    Returns:
        Two-item list from the split if delimiter is found in `s`,
        otherwise a two-item tuple with `default` in place of the missing part.
    """
    delim = sep.strip('<>')
    from_left = sep.strip(delim) == '<'
    if delim in s:
        return s.split(delim, 1) if from_left else s.rsplit(delim, 1)
    return (default, s) if from_left else (s, default)
def sockopt_resolve(prefix, v):
    """Translate socket module constant value `v` to its name, for logging.

    Looks through socket module attributes starting with upper-cased
    `prefix` (e.g. "af_" or "sock_") and returns the name of the first
    one equal to `v`, with the prefix removed. Returns `v` itself if
    there is no such attribute.
    """
    prefix = prefix.upper()
    names = (
        attr[len(prefix):] for attr in dir(socket)
        if attr.startswith(prefix) and getattr(socket, attr) == v)
    return next(names, v)
# Port to connect to when none is specified via the host spec,
# the second positional argument or the -p/--port option.
port_default = 22
def main(args=None):
    """Command-line entry point: wait for a tcp port, then optionally exec ssh.

    Parses the options, resolves destination via getaddrinfo (first result
    is used), runs the connection-probing loop until a connection can be
    established or a signal stops it, and with -s/--ssh replaces the
    process with an ssh command to the resolved address afterwards.

    Args:
        args: list of command-line arguments, sys.argv[1:] is used if None.

    Returns:
        Non-zero exit code if the probing loop was interrupted, otherwise None.
        Invalid arguments or a resolver failure exit via parser.error().
    """
    import argparse
    parser = argparse.ArgumentParser(
        description='Tool to "ping" tcp port until connection can be established.'
            ' Main use is waiting for something to come online,'
            ' e.g. waiting for when ssh port is accessible after remote host reboots.'
            ' Unlike common icmp ping tools, does not need elevated privileges.')

    group = parser.add_argument_group('Destination spec')
    group.add_argument('host',
        help='Host/address (to be resolve via gai) or [user@]host[:port] spec.'
            ' User part can be specified for -s/--ssh option and is ignored otherwise.')
    group.add_argument('port_arg', nargs='?',
        help='TCP service/port name/number to connect to.'
            ' Can also be specified as part of the host or -p/--port option.'
            ' Default: {}'.format(port_default))
    group.add_argument('-p', '--port', metavar='tcp-port',
        help='TCP service/port name/number to connect to.'
            ' Can also be specified as part of the host or a second positional argument.'
            ' Default: {}'.format(port_default))

    group = parser.add_argument_group('Delays/timeouts')
    group.add_argument('-r', '--retry-delay',
        type=float, metavar='seconds', default=1.0,
        help='Delay between trying to establish new connections.'
            ' Does not take any other started connections/handshakes into account.'
            ' Default: %(default)ss')
    group.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=3.0,
        help='Delay after which to consider establishing'
            ' connection as timed-out and cancel it. Default: %(default)ss')

    group = parser.add_argument_group('Extra actions')
    group.add_argument('-s', '--ssh', action='store_true',
        help='Exec "ssh -p<port> <user@host>" after successful tcp connection instead of exit.'
            ' Default user is "root", if not specified in the "host" argument.')
    group.add_argument('-o', '--ssh-opts', action='append',
        help='Extra options for ssh command, if -s/--ssh flag is used.'
            ' Will be split on spaces, unless option is used multiple times.'
            ' If has no spaces and does not start with a dash, single dash is prepended,'
            " e.g. -o=NnT -> -o=-NnT, purely for convenience with few short options."
            ' Can contain special "@@@" token (as a separate arg) to'
            ' substitute user@host in its place, otherwise it will be appended.'
            " Example: -o='-NnT -i mykey @@@ -- cat /etc/hosts'")

    group = parser.add_argument_group('Debug/misc')
    group.add_argument('-d', '--debug', action='store_true', help='Verbose operation mode.')

    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    global log
    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)
    log = get_logger('main')

    # Port can come from three places, which must agree if more than one is used.
    user, host = str_part(opts.host, '<@', 'root')
    host, port = str_part(host, ':>')
    ports = set(filter(None, [opts.port_arg, opts.port, port]))
    if ports:
        if len(ports) > 1:
            parser.error( 'Conflicting (different) port'
                ' specs: {}'.format(', '.join(str(p) for p in ports)) )
        port = ports.pop()
    else: port = port_default

    # Single -o value is split into separate ssh arguments,
    # with a dash added to a lone word like "NnT" for convenience.
    ssh_opts = opts.ssh_opts or list()
    if len(ssh_opts) == 1:
        ssh_opts = ssh_opts[0].strip().split()
        if len(ssh_opts) == 1 and not ssh_opts[0].startswith('-'): ssh_opts[0] = f'-{ssh_opts[0]}'

    try:
        addrinfo = socket.getaddrinfo(
            host, str(port), type=socket.SOCK_STREAM, proto=socket.IPPROTO_TCP )
        if not addrinfo: raise socket.gaierror('No addrinfo for host: {}'.format(host))
    except (socket.gaierror, socket.error) as err:
        parser.error( 'Failed to resolve socket parameters (address, family)'
            ' via getaddrinfo: {!r} - [{}] {}'.format((host, port), err.__class__.__name__, err) )
    sock_af, sock_t, _, _, sock_addr = addrinfo[0]
    log.debug( 'Resolved host:port {!r}:{!r} to endpoint: {} (family: {}, type: {})',
        host, port, sock_addr, sockopt_resolve('af_', sock_af), sockopt_resolve('sock_', sock_t) )

    # Dedicated loop is created explicitly: asyncio.get_event_loop() without
    # a running loop is deprecated and fails on newer python versions.
    with contextlib.closing(asyncio.new_event_loop()) as loop:
        exit_code = loop.run_until_complete(run_tping(
            loop, sock_addr, sock_af, opts.retry_delay, opts.timeout ))
    if exit_code: return exit_code

    if opts.ssh:
        # ssh is pointed to the exact address/port that was just checked.
        host, port = sock_addr[:2]
        cmd = ['ssh', f'-p{port:d}'] + ssh_opts
        if '@@@' not in cmd: cmd.append('@@@')
        cmd[cmd.index('@@@')] = f'{user}@{host}'
        log.debug('Running ssh command: {}', ' '.join(cmd))
        os.execvp('ssh', cmd)
# Run main() only when executed as a script, using its result as the exit code.
if __name__ == '__main__':
    sys.exit(main())