#!/usr/bin/env python3
import os, sys, logging, contextlib, asyncio, socket, signal
import math, hashlib, secrets, struct, subprocess
class SSHMuxConfig:
    '''Default settings for the mux client, used by main() as option defaults.'''

    # Placeholder shared secret, expected to be replaced via -s/--auth-secret.
    auth_secret = 'set this via cli option!'

    # Defaults for the -n/--attempts, -m/--mux-port and -t/--timeout options.
    mux_attempts = 6
    mux_port = 8739
    mux_timeout = 10.0

    # Base ssh command-line options, one list item per argument.
    # Must stay a list - main() concatenates it with other lists.
    ssh_opts = [
        '-oControlPath=none', '-oControlMaster=no',
        '-oConnectTimeout=180', '-oServerAliveInterval=6', '-oServerAliveCountMax=10',
        '-oBatchMode=yes', '-oPasswordAuthentication=no', '-oNumberOfPasswordPrompts=0',
        '-oExitOnForwardFailure=yes', '-NnT' ]
class LogMessage(object):
    '''Log message that applies str.format() only when converted to a string.'''

    def __init__(self, fmt, a, k):
        # fmt - format template, a - positional args, k - keyword args for it.
        self.fmt, self.a, self.k = fmt, a, k

    def __str__(self):
        # Without any args the template is returned verbatim,
        #  so it is not run through format() at all in that case.
        if not (self.a or self.k): return self.fmt
        return self.fmt.format(*self.a, **self.k)
class LogStyleAdapter(logging.LoggerAdapter):
    '''Logger adapter for messages that use str.format() ("{}") placeholders.'''

    def __init__(self, logger, extra=None):
        super(LogStyleAdapter, self).__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        '''Emit msg at level, wrapped into LogMessage for deferred formatting.

        Positional args and remaining keywords are handed to LogMessage
         as format arguments; "exc_info" keyword is passed on to the logger.'''
        if not self.isEnabledFor(level): return
        log_kws = {} if 'exc_info' not in kws else dict(exc_info=kws.pop('exc_info'))
        msg, kws = self.process(msg, kws)
        # exc_info has to be expanded as a keyword here: passing the dict
        #  positionally made the (always truthy) dict itself the exc_info value,
        #  so exc_info=False or an explicit exception tuple were not honored.
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    'Return LogStyleAdapter wrapping the stdlib logger with the specified name.'
    return LogStyleAdapter(logging.getLogger(name))
async def asyncio_wait_or_cancel(
        loop, task, timeout, default=..., cancel_suppress=None ):
    '''Wait for task (or coroutine) to finish within timeout and return its result.

    Coroutines are wrapped into a task on the specified loop first.
    On timeout, task is awaited once more with CancelledError
     and any exception types listed in cancel_suppress being suppressed,
     then default is returned, or asyncio.TimeoutError re-raised if none was given.'''
    if asyncio.iscoroutine(task): task = loop.create_task(task)
    try: return await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError as err:
        suppressed = [asyncio.CancelledError, *(cancel_suppress or ())]
        with contextlib.suppress(*suppressed): await task
        # Ellipsis is the "no default" marker, so that None can be a valid default.
        if default is ...: raise err
        return default
def retries_within_timeout( tries, timeout,
        backoff_func=lambda e,n: ((e**n-1)/e), slack=1e-2 ):
    '''Return list of delays to make exactly n tries within timeout, with backoff_func.

    backoff_func(base, n) gives delay number n (counting from 0) for some base,
     and that base is found by bisection over the [0, timeout] range,
     until the sum of all delays is within slack of the timeout.
    Raises ValueError if no such base can be found in that range.'''
    lo, hi = 0, timeout
    # Float interval collapses long before this many halvings, so hitting
    #  the limit means there is no solution - this used to loop forever instead.
    for _ in range(200):
        base = (lo + hi) / 2
        delays = [backoff_func(base, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack: return delays
        if error > 0: hi = base
        else: lo = base
    raise ValueError(
        'Unable to fit {} tries into {}s timeout'
        ' with specified backoff function'.format(tries, timeout) )
def to_bytes(s):
    'Return s as-is if it is bytes already, otherwise encoded str() of it.'
    if isinstance(s, bytes): return s
    return str(s).encode()
class MuxClientProtocol:
    '''asyncio datagram protocol that queues received packets for mux_negotiate().

    Every datagram is put into the "responses" queue as-is,
     and None is put there when connection is lost,
     which is the marker that mux_negotiate() checks for.'''

    transport = None

    def __init__(self, loop):
        # "loop" is kept in the signature for callers, but is not passed on -
        #  asyncio.Queue no longer accepts loop= argument since python 3.10.
        self.responses = asyncio.Queue()
        self.log = get_logger('mux-client.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        # Without this the queue stays empty and negotiation can never succeed.
        self.responses.put_nowait(data)

    def error_received(self, err):
        self.log.debug('Network error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err)
        self.responses.put_nowait(None)
class AuthError(Exception):
    'Raised by parse_res() for responses with invalid structure or MAC.'
def build_req(secret, ident):
    '''Return request packet (bytes) for ident, authenticated with secret.

    Packet layout: one byte with ident length, ident itself, 16-byte random salt,
     then blake2b MAC of the ident, keyed with secret and using that salt.
    Raises ValueError if ident is too long for its length to fit into one byte.'''
    ident_buff = to_bytes(ident)
    if len(ident_buff) > 255:
        # Same input used to fail with a less obvious struct.error in pack() below.
        raise ValueError( 'Ident string is too long'
            ' ({:,d}B, max is 255B)'.format(len(ident_buff)) )
    salt = os.urandom(16)
    mac = hashlib.blake2b(ident_buff, key=to_bytes(secret), salt=salt).digest()
    return struct.pack('>B', len(ident_buff)) + ident_buff + salt + mac
def parse_res(secret, ident, res):
    '''Parse and authenticate response packet, returning (ssh_port, tun_port).

    Expected layout: one byte with payload length, payload, 16-byte salt,
     then blake2b MAC of ident + payload, keyed with secret and using that salt.
    Payload is two big-endian 16-bit values - ssh port and tunnel port.
    Returns None for empty res, and for responses with bad structure or MAC,
     with the latter also logged via module-level "log" (set up in main).'''
    if not res: return
    try:
        res_len = res[0]
        payload, salt, mac = (
            res[1:res_len+1], res[res_len+1:res_len+17], res[res_len+17:] )
        if len(payload) != res_len or not salt or not mac:
            raise AuthError('Invalid structure')
        mac_chk = hashlib.blake2b(
            to_bytes(ident) + payload, key=to_bytes(secret), salt=salt ).digest()
        # Check used to be inverted - i.e. valid responses were the ones rejected.
        if not secrets.compare_digest(mac, mac_chk): raise AuthError('MAC mismatch')
        # Payload is only unpacked after it is authenticated.
        try: ssh_port, tun_port = struct.unpack('>HH', payload)
        except struct.error: raise AuthError('Invalid payload length')
    except AuthError as err:
        log.debug('Failed to parse/auth response value: {}', err)
        # Nothing valid to return here, caller will wait for next response.
        return
    return ssh_port, tun_port
async def mux_negotiate(loop, secret, ident, sock_af, sock_p, host, port, delays):
    '''Send request to mux server until valid response is received, and return it.

    Request from build_req() is sent over UDP endpoint to host/port,
     with each value in delays being time to wait for responses before next send.
    Returns first non-empty parse_res() result, i.e. (ssh_port, tun_port) tuple.
    After the last send it keeps waiting for response indefinitely,
     so caller is expected to enforce overall timeout by cancelling this task.'''
    req = build_req(secret, ident)
    transport = proto = None
    try:
        # Extra huge delay at the end is that "wait until cancelled" step.
        for delay in delays + [2**30]:
            deadline = loop.time() + delay
            if not transport:
                transport, proto = await loop.create_datagram_endpoint(
                    lambda: MuxClientProtocol(loop),
                    remote_addr=(host, port), family=sock_af, proto=sock_p )
            if delay:
                while True:
                    try:
                        response = await asyncio_wait_or_cancel( loop,
                            proto.responses.get(), max(0, deadline - loop.time()) )
                    except asyncio.TimeoutError: break
                    if response is None:
                        # None in the queue means that endpoint is gone.
                        # Stop reading from it here, it will be
                        #  re-created on the next iteration of the outer loop.
                        transport = proto = None
                        break
                    response = parse_res(secret, ident, response)
                    if response: return response
            if transport: transport.sendto(req)
            # No loop= here - it was removed from asyncio.sleep in python 3.10.
            await asyncio.sleep(max(0, deadline - loop.time()))
    finally:
        if transport: transport.close()
def sockopt_resolve(prefix, v):
    '''Return name of the socket module constant that has specified prefix and value.

    Prefix is matched in upper case and stripped from the returned name,
     e.g. ("af_", socket.AF_INET) gives "INET".
    Value itself is returned if no such constant is found.'''
    prefix = prefix.upper()
    for name in dir(socket):
        if name.startswith(prefix) and getattr(socket, name) == v:
            return name[len(prefix):]
    return v
def main(args=None, conf=None):
    '''Command-line entry point: negotiate ports with mux server, then exec ssh.

    args is a list of command-line arguments (sys.argv[1:] if not specified),
     conf is an object with defaults for these (SSHMuxConfig instance by default).
    Returns 1 if negotiation gets cancelled or times out.
    Otherwise process is replaced by ssh via os.execvp, so it does not return.'''
    if not conf: conf = SSHMuxConfig()

    import argparse
    parser = argparse.ArgumentParser(
        description='Wrapper for "ssh -R" to query remote'
            ' listen port number to use from server (based on some unique id)'
            ' and substitute that into resulting command.')

    # NOTE(review): add_argument() opener for this help text was missing,
    #  "host" name is taken from the -m/--mux-port option description below.
    parser.add_argument('host',
        help='Host or address (to be resolved via gai) or a [user@]host[:port] spec.'
            ' "port" will be used for -m/--mux-port option, and user is a remote ssh username to use.')

    parser.add_argument('-s', '--auth-secret',
        default=conf.auth_secret, metavar='string',
        help='Any string to use as symmetric secret'
            ' to authenticate both sides on --mux-port (default: %(default)s).'
            ' Must be same for both ssh-reverse-mux-client and server scripts talking to each other.')

    parser.add_argument('-i', '--ident-string',
        help='Any string to use as this node identity -'
            ' i.e. serial number, mac/hw address, machine-id, etc.'
            ' Hash of /etc/machine-id contents is used, if not specified.'
            ' Overrides any other --ident-* option.')
    parser.add_argument('--ident-rpi', action='store_true',
        help='Use hash of "Serial" from /proc/cpuinfo as ident.'
            ' Only available on Raspberry Pi boards.')
    # NOTE(review): add_argument() opener for this help text was missing too,
    #  option name follows from opts.ident_cmd use below, metavar is a guess.
    parser.add_argument('--ident-cmd', metavar='command',
        help='Shell command to run to get ident string on stdout.'
            ' Must exit with code 0, otherwise script will abort.'
            ' Resulting string be stripped of spaces, otherwise sent as-is,'
            ' so should be hashed in the command if necessary.')

    parser.add_argument('-m', '--mux-port',
        default=conf.mux_port, type=int, metavar='port',
        help='Remote UDP port on which corresponding'
            ' ssh-reverse-mux-server script is listening on (default: %(default)s).'
            ' Can also be specified in the "host" argument, which overrides this option.')
    parser.add_argument('-p', '--ssh-port',
        type=int, metavar='port',
        help='Remote ssh port to connect to.'
            ' Default is to use one provided by ssh-reverse-mux-server script via --mux-port.')

    parser.add_argument('-n', '--attempts',
        type=int, metavar='n', default=conf.mux_attempts,
        help='Number of UDP packets to send to'
            ' --mux-port (to offset packet loss). Default: %(default)s')
    parser.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=conf.mux_timeout,
        help='Negotiation response timeout on --mux-port, in seconds. Default: %(default)ss')

    parser.add_argument('-c', '--mux-hook',
        action='append', metavar='command/args',
        help='Command to run after successful mux-server negotiation and right before ssh exec.'
            ' Will be run via Popen, with PATH lookup, no shell, and arguments split on spaces'
            ' if option is specified once (e.g.: -c "logger arg1 arg2"),'
            ' otherwise each arg is passed as-is (e.g.: -c logger -c arg1 -c "arg2 with spaces").'
            ' Two extra arguments always appended -'
            ' remote ssh port and tunnel listening port, as received from mux-server.')

    parser.add_argument('-d', '--debug', action='store_true', help='Verbose operation mode.')
    parser.add_argument('--debug-ssh', action='store_true',
        help='Add verbose-mode options to ssh command line.')
    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    # Module-level "log" is also what parse_res() uses for its messages.
    global log
    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)
    log = get_logger('mux-client.main')

    ident = opts.ident_string
    if not ident:
        if opts.ident_rpi:
            import re
            with open('/proc/cpuinfo') as src:
                for line in src:
                    # NOTE(review): call here was missing, re.search is assumed.
                    m = re.search(r'^\s*Serial\s*:\s*(\S+)\s*$', line)
                    if m: break
                else: parser.error('Failed to find "Serial : ..." line in /proc/cpuinfo (non-RPi kernel?)')
            # NOTE(review): hashed value was missing, matched serial is assumed.
            ident = hashlib.blake2b(
                m.group(1).encode(), key=to_bytes(opts.auth_secret) ).digest()
        elif opts.ident_cmd:
            # Shell is used intentionally here - option is documented
            #  as a shell command, and is specified by whoever runs the script.
            # NOTE(review): call here was missing, subprocess.run is assumed.
            res = subprocess.run( opts.ident_cmd,
                shell=True, check=True, stdout=subprocess.PIPE )
            ident = res.stdout.decode().strip()
        else:
            with open('/etc/machine-id', 'rb') as src:
                # NOTE(review): hashed value was missing, and full file contents
                #  are assumed, as per -i/--ident-string option description.
                ident = hashlib.blake2b(
                    src.read(), key=to_bytes(opts.auth_secret) ).digest()

    # Single -c/--mux-hook value gets split into command and its arguments.
    mux_hook = opts.mux_hook
    if mux_hook and len(mux_hook) == 1: mux_hook = mux_hook[0].split()

    # NOTE(review): split call here was missing, rsplit of the host arg is assumed.
    user, host = None, opts.host.rsplit(':', 1)
    host, port_mux = (host[0], opts.mux_port) if len(host) == 1 else host
    if '@' in host: user, host = host.split('@', 1)

    try:
        addrinfo = socket.getaddrinfo(
            host, str(port_mux), type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP )
        if not addrinfo: raise socket.gaierror('No addrinfo for host: {}'.format(host))
    except (socket.gaierror, socket.error) as err:
        parser.error( 'Failed to resolve socket parameters (address, family)'
            ' via getaddrinfo: {!r} - [{}] {}'.format((host, port_mux), err.__class__.__name__, err) )
    sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
    log.debug(
        'Resolved mux host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
        host, port_mux, sock_addr,
        *(sockopt_resolve(pre, n) for pre, n in [
            ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p) ]) )

    ssh_login = f'{user}@{host}' if user else host
    # One more delay than there are attempts is calculated and then dropped,
    #  as mux_negotiate() adds its own wait after the last packet is sent.
    retry_delays = retries_within_timeout(opts.attempts+1, opts.timeout)[:-1]

    # New loop is created explicitly, as relying on get_event_loop()
    #  to do that when there is no current loop is deprecated in newer pythons.
    with contextlib.closing(asyncio.new_event_loop()) as loop:
        muxer = loop.create_task(mux_negotiate( loop,
            opts.auth_secret, ident, sock_af, sock_p, host, port_mux, retry_delays ))
        for sig in 'INT TERM'.split():
            loop.add_signal_handler(getattr(signal, f'SIG{sig}'), muxer.cancel)
        try:
            ssh_port, tun_port = loop.run_until_complete(
                asyncio_wait_or_cancel(loop, muxer, opts.timeout) )
        except (asyncio.CancelledError, asyncio.TimeoutError) as err:
            log.debug('mux_negotiate cancelled ({})', err.__class__.__name__)
            # There are no ports to use, so exit with error code
            #  instead of running into undefined ssh_port/tun_port below.
            return 1

    log.debug( 'Negotiated ssh params:'
        ' ssh -R {}:localhost:22 -p{} {!r}', tun_port, ssh_port, ssh_login )

    if mux_hook:
        mux_hook.extend(map(str, [ssh_port, tun_port]))
        log.debug('Running hook command: {}', ' '.join(mux_hook))
        subprocess.Popen(mux_hook, close_fds=True).wait()

    if opts.ssh_port: ssh_port = opts.ssh_port
    # Options are copied to avoid "+=" below changing the shared conf.ssh_opts list.
    ssh_cmd = list(conf.ssh_opts)
    if opts.debug_ssh: ssh_cmd += ['-vvv']
    ssh_cmd = ( ['ssh'] + ssh_cmd +
        [f'-p{ssh_port}', '-R', f'{tun_port}:localhost:22', ssh_login] )
    log.debug('Resulting ssh command: {!r}', ssh_cmd)
    os.execvp('ssh', ssh_cmd)
if __name__ == '__main__':
    # Value returned from main() is used as the exit code.
    sys.exit(main())