Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
executable file 255 lines (210 sloc) 9.42 KB
#!/usr/bin/env python3
import os, sys, logging, contextlib, asyncio, socket, signal
import math, hashlib, secrets, struct, shelve, base64
class SSHMuxConfig:
    """Built-in defaults for the command-line options."""

    # Placeholder shared secret, meant to be replaced via the cli option.
    auth_secret = 'set this via cli option!'

    # UDP negotiation: listening port, response packets per request, time budget (seconds).
    mux_port = 8739
    mux_attempts = 4
    mux_timeout = 5.0

    # sshd port reported to clients and the "from:to" range for tunnel ports.
    ssh_port = 22
    tunnel_port_range = '22000:22100'

    # Path to the db file with persistent ident -> port allocations.
    ident_db_path = 'ssh-reverse-mux-ident.db'
class LogMessage(object):
    """Log message whose str.format()-style formatting is deferred until str() is called.

    fmt: format string, or any other object to be logged as str(fmt).
    a, k: positional and keyword arguments for fmt.format().
    """

    def __init__(self, fmt, a, k):
        self.fmt, self.a, self.k = fmt, a, k

    def __str__(self):
        # Non-string messages (e.g. an exception object) have no format fields,
        # and __str__ must return str, so they are converted and returned as-is.
        if not isinstance(self.fmt, str): return str(self.fmt)
        return self.fmt.format(*self.a, **self.k) if self.a or self.k else self.fmt
class LogStyleAdapter(logging.LoggerAdapter):
    """LoggerAdapter that formats messages with str.format() instead of %-style.

    Positional and keyword arguments of the logging calls are wrapped
    into a LogMessage object together with the format string.
    """

    def __init__(self, logger, extra=None):
        super(LogStyleAdapter, self).__init__(logger, extra or {})

    def log(self, level, msg, *args, **kws):
        if not self.isEnabledFor(level): return
        # exc_info is an option for the logging machinery, not a format field.
        log_kws = {} if 'exc_info' not in kws else dict(exc_info=kws.pop('exc_info'))
        msg, kws = self.process(msg, kws)
        # Must be passed as a keyword: positionally the dict itself would become
        # the exc_info value, which breaks exc_info=False and exception instances.
        self.logger._log(level, LogMessage(msg, args, kws), (), **log_kws)
def get_logger(name):
    """Return the stdlib logger with the given name wrapped into LogStyleAdapter."""
    return LogStyleAdapter(logging.getLogger(name))
async def asyncio_wait_or_cancel(
        loop, task, timeout, default=..., cancel_suppress=None ):
    """Await task for up to timeout seconds and return its result.

    A coroutine passed as task is wrapped into a task on the loop first.
    On timeout, the task is awaited once more, ignoring CancelledError
    and any exception types listed in cancel_suppress, and then either
    default is returned, or (if it was not specified) the timeout error is re-raised.
    """
    if asyncio.iscoroutine(task):
        task = loop.create_task(task)
    try:
        result = await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError as err:
        ignored = (asyncio.CancelledError, *(cancel_suppress or ()))
        with contextlib.suppress(*ignored):
            await task
        if default is not ...:
            return default
        raise err
    return result
def retries_within_timeout( tries, timeout,
        backoff_func=lambda e,n: ((e**n-1)/e), slack=1e-2 ):
    """Return list of delays to make exactly n tries within timeout, with backoff_func.

    tries: number of delays to return.
    timeout: value that the sum of all returned delays should add up to.
    backoff_func: called as backoff_func(e, n) to get the delay for n-th try,
        with the parameter e picked by bisection over the (0, timeout) interval.
    slack: max allowed difference between the sum of delays and timeout.

    If there are no tries or no time budget, zero delays are returned.
    If no value of e within the interval gets the sum close enough to timeout
    (e.g. 1 or 2 tries with the default backoff_func), the closest delays that
    were found are returned, with negative ones replaced by zero, instead of
    searching forever.
    """
    if tries <= 0 or timeout <= 0: return [0.0] * max(tries, 0)
    lo, hi = 0.0, float(timeout)
    delays = [0.0] * tries
    while True:
        mid = (lo + hi) / 2
        # Interval can't be split any further at float precision - no exact solution.
        if not lo < mid < hi: break
        delays = [backoff_func(mid, n) for n in range(tries)]
        error = sum(delays) - timeout
        if abs(error) < slack: return delays
        elif error > 0: hi = mid
        else: lo = mid
    return [max(0.0, d) for d in delays]
def to_bytes(s):
    """Return s as-is if it is bytes already, otherwise its str() form encoded to bytes."""
    if isinstance(s, bytes):
        return s
    return str(s).encode()
class MuxServerProtocol:
    """Datagram protocol that puts every received packet into a queue.

    Received (data, addr) tuples are available from the "requests" queue,
    and the transport (once connected) is stored in the "transport" attribute.
    """

    transport = None

    def __init__(self, loop):
        # "loop" is kept in the signature for existing callers, but is not passed
        # on: asyncio.Queue does not accept loop= argument since Python 3.10.
        self.requests = asyncio.Queue()
        self.log = get_logger('mux-server.udp')

    def connection_made(self, transport):
        self.log.debug('Connection made')
        self.transport = transport

    def datagram_received(self, data, addr):
        self.log.debug('Received {:,d}B from {!r}', len(data), addr)
        self.requests.put_nowait((data, addr))

    def error_received(self, err):
        # Datagram transports call this on send/receive OSError, so it has to
        # exist, otherwise such an error turns into an AttributeError instead.
        self.log.debug('Network error: {}', err)

    def connection_lost(self, err):
        self.log.debug('Connection lost: {}', err)
class AuthError(Exception):
    """Request has invalid structure or its MAC does not match."""
def parse_req(secret, req):
    """Check structure and MAC of the request packet and return ident from it.

    Packet layout, as parsed here: one byte with ident length, ident itself,
    up to 16 bytes of salt, and the rest being blake2b digest of the ident,
    keyed with the secret and using that salt.

    Returns urlsafe-base64 string of the ident bytes, or None
    (after logging the reason) if request is malformed or fails the MAC check.
    """
    try:
        # Packets come from the network, so an empty one must not raise IndexError.
        if not req: raise AuthError('Invalid structure')
        ident_len = req[0]
        ident, salt, mac = req[1:ident_len+1], req[ident_len+1:ident_len+17], req[ident_len+17:]
        if len(ident) != ident_len or not salt or not mac: raise AuthError('Invalid structure')
        mac_chk = hashlib.blake2b(ident, key=to_bytes(secret), salt=salt).digest()
        # Constant-time comparison, to avoid leaking how much of the MAC matched.
        if not secrets.compare_digest(mac, mac_chk): raise AuthError('MAC mismatch')
    except AuthError as err:
        log.debug('Failed to parse/auth request value: {}', err)
        return None
    return base64.urlsafe_b64encode(ident).decode()
def build_res(secret, ident, tun_port, ssh_port):
    """Build authenticated response packet with the ssh and tunnel ports.

    Result is: one byte with the length of the ports block, two big-endian
    16-bit ports (ssh, tunnel), 16 random bytes of salt, and blake2b digest
    of ident + ports block, keyed with the secret and using that salt.
    """
    ports = struct.pack('>HH', ssh_port, tun_port)
    salt = os.urandom(16)
    signer = hashlib.blake2b(to_bytes(ident) + ports, key=to_bytes(secret), salt=salt)
    return b''.join([struct.pack('>B', len(ports)), ports, salt, signer.digest()])
def ident_repr(ident):
    """Return printable "[type] value" representation of a base64-encoded ident.

    If ident decodes from urlsafe-base64 to valid text, that text is
    used with the "str" type, otherwise ident is shown as-is with "b64" type.
    """
    try: ident_t, ident_dec = 'str', base64.urlsafe_b64decode(ident).decode()
    # ValueError covers both UnicodeDecodeError for non-text ident bytes
    # and binascii.Error for a value that is not valid base64 at all.
    except ValueError: ident_t, ident_dec = 'b64', ident
    return f'[{ident_t}] {ident_dec!r}'
async def mux_send(loop, transport, response, addr, delays):
    """Send response to addr once for each value in delays, sleeping that long after each send.

    loop: unused, kept so that existing callers do not have to change.
    transport: object with sendto(data, addr) method to send packets through.
    delays: iterable of sleep intervals in seconds, one per packet to send.

    Sleep after the last packet is kept as well, so that this
    coroutine stays pending for the total duration of all delays.
    """
    for delay in delays:
        transport.sendto(response, addr)
        # No loop= here: asyncio.sleep does not accept it since Python 3.10.
        await asyncio.sleep(delay)
async def mux_listen( loop, secret, ident_db,
        sock_af, sock_p, host, port, tun_port_a, tun_port_b, ssh_port, delays ):
    """Listen for mux requests on UDP host:port and respond with allocated ports.

    secret: shared secret used to check requests and sign responses.
    ident_db: dict-like ident -> tunnel port mapping, reused and updated here.
    sock_af, sock_p: socket family and protocol for the datagram endpoint.
    tun_port_a, tun_port_b: inclusive range to allocate tunnel ports from.
    ssh_port: sshd port to include in responses.
    delays: delays between response packets, passed on to mux_send().

    Runs until cancelled or until some error is raised, at which point it waits
    for responses that are still being sent and closes the datagram transport.
    """
    def tun_ports_iter_func():
        tun_ports_used = set(ident_db.values())
        for tun_port in range(tun_port_a, tun_port_b + 1):
            if tun_port not in tun_ports_used: yield tun_port
    tun_ports_iter = tun_ports_iter_func()

    responses = dict()
    transport, proto = await loop.create_datagram_endpoint(
        lambda: MuxServerProtocol(loop), local_addr=(host, port), family=sock_af, proto=sock_p )

    try:
        while True:
            req, addr = await proto.requests.get()
            ident = parse_req(secret, req)
            if not ident: continue

            if ident in responses:
                # Response to this ident is still being sent - request is a duplicate.
                if not responses[ident].done(): continue
                await responses[ident]

            tun_port = ident_db.get(ident)
            if not tun_port or not tun_port_a <= tun_port <= tun_port_b:
                try: tun_port = next(tun_ports_iter)
                except StopIteration:
                    log.debug( 'No more ports to allocate'
                        ' ident: {} (addr={})', ident_repr(ident), addr )
                    # Nothing to respond with, and the old value must not be stored or packed.
                    continue
                ident_db[ident] = tun_port

            response = build_res(secret, ident, tun_port, ssh_port)
            log.debug(
                'Allocated [tun={}, ssh={}] for ident: {} (addr={})',
                tun_port, ssh_port, ident_repr(ident), addr )
            responses[ident] = loop.create_task(
                mux_send(loop, transport, response, addr, delays) )

    finally:
        # Loop above only ends via exception or cancellation, so the cleanup
        # has to be here: finish pending responses first, then close the socket.
        try:
            for response in responses.values(): await response
        finally: transport.close()
def sockopt_resolve(prefix, v):
    """Return name of the socket module constant with the given prefix and value v.

    Prefix is matched in upper-case and stripped from the returned name.
    If there is no such constant, v itself is returned.
    """
    tag = prefix.upper()
    matches = (
        name[len(tag):] for name in dir(socket)
        if name.startswith(tag) and getattr(socket, name) == v )
    return next(matches, v)
def main(args=None, conf=None):
    """Script entry point: parse cli options and run the mux listener.

    args: list of command-line arguments to use instead of sys.argv[1:].
    conf: object with defaults for the options, SSHMuxConfig() if not specified.

    Runs until the listener task fails or gets cancelled by SIGINT/SIGTERM,
    or only prints stored ident-port mappings if -l/--ident-list is used.
    Option errors are reported via parser.error(), which exits the script.
    """
    if not conf: conf = SSHMuxConfig()

    import argparse
    parser = argparse.ArgumentParser(
        description='Multiplexer for "ssh -R" connections,'
            ' directing each one to unique port(s) according to provided ident-sting.')

    parser.add_argument('bind', nargs='?', default='::',
        help='Host or address (to be resolved via gai) to listen on.'
            ' Default is to use "::" wildcard IPv4/IPv6 binding.')

    parser.add_argument('-s', '--auth-secret',
        default=conf.auth_secret, metavar='string',
        help='Any string to use as symmetric secret'
            ' to authenticate both sides on --mux-port (default: %(default)s).'
            ' Must be same for both ssh-reverse-mux-client and server scripts talking to each other.')

    parser.add_argument('-i', '--ident-db',
        default=conf.ident_db_path, metavar='path',
        help='Path to db to store all the seen clients to, for persistent port allocation.'
            ' Default: %(default)s')
    parser.add_argument('-l', '--ident-list',
        action='store_true', help='List stored ident-port mappings and exit.')

    parser.add_argument('-m', '--mux-port',
        default=conf.mux_port, type=int, metavar='port',
        help='Local UDP port to listen on for muxer requests from clients (default: %(default)s).'
            ' Can also be specified in the "bind" argument, which overrides this option.')
    parser.add_argument('-p', '--ssh-port',
        type=int, metavar='port', default=conf.ssh_port,
        help='Local sshd port to send to clients (default: %(default)s).')
    parser.add_argument('-r', '--tunnel-port-range',
        metavar='port_from:port_to', default=conf.tunnel_port_range,
        help='Range in which to allocate'
            ' reverse-tunnel listening ports, inclusive. Default: %(default)s')

    parser.add_argument('-n', '--attempts',
        type=int, metavar='n', default=conf.mux_attempts,
        help='Number of UDP packets to send from'
            ' --mux-port in response to clients (to offset packet loss). Default: %(default)s')
    parser.add_argument('-t', '--timeout',
        type=float, metavar='seconds', default=conf.mux_timeout,
        help='Negotiation response timeout on --mux-port, in seconds. Default: %(default)ss')

    parser.add_argument('-d', '--debug', action='store_true', help='Verbose operation mode.')
    opts = parser.parse_args(sys.argv[1:] if args is None else args)

    # Module-level logger, which other functions in this file use as well.
    global log
    logging.basicConfig(level=logging.DEBUG if opts.debug else logging.WARNING)
    log = get_logger('mux-server.main')

    # Context manager makes sure that db is closed on any exit path from here.
    with shelve.open(opts.ident_db, 'c') as ident_db:

        if opts.ident_list:
            for ident, tun_port in ident_db.items():
                print('port {} :: {}'.format(tun_port, ident_repr(ident)))
            # Option is documented as "list and exit", so listener is not started.
            return

        host, port_mux = opts.bind, opts.mux_port
        try:
            addrinfo = socket.getaddrinfo(
                host, str(port_mux), type=socket.SOCK_DGRAM, proto=socket.IPPROTO_UDP )
            if not addrinfo: raise socket.gaierror('No addrinfo for host: {}'.format(host))
        except (socket.gaierror, socket.error) as err:
            parser.error( 'Failed to resolve socket parameters (address, family)'
                ' via getaddrinfo: {!r} - [{}] {}'.format((host, port_mux), err.__class__.__name__, err) )
        sock_af, sock_t, sock_p, _, sock_addr = addrinfo[0]
        log.debug(
            'Resolved mux host:port {!r}:{!r} to endpoint: {} (family: {}, type: {}, proto: {})',
            host, port_mux, sock_addr,
            *(sockopt_resolve(pre, n) for pre, n in [
                ('af_', sock_af), ('sock_', sock_t), ('ipproto_', sock_p) ]) )

        try:
            tun_port_a, tun_port_b = map(int, opts.tunnel_port_range.split(':', 1))
            for p in tun_port_a, tun_port_b:
                # 65535 is the highest valid port number and should be allowed.
                if not 0 < p <= 65535: raise ValueError(f'Out of range: {p!r}')
            if tun_port_a > tun_port_b: raise ValueError(tun_port_a, tun_port_b)
        except Exception as err:
            parser.error(f'Failed to parse tunnel port range: [{err.__class__.__name__}] {err}')
        log.debug(
            'Parsed tunnel port range: {}-{} ({} port(s))',
            tun_port_a, tun_port_b, tun_port_b - tun_port_a + 1 )

        retry_delays = retries_within_timeout(opts.attempts, opts.timeout)

        # New loop is created explicitly, as asyncio.get_event_loop()
        # is deprecated for the case when there is no running event loop.
        with contextlib.closing(asyncio.new_event_loop()) as loop:
            muxer = loop.create_task(mux_listen( loop, opts.auth_secret, ident_db,
                sock_af, sock_p, host, port_mux, tun_port_a, tun_port_b, opts.ssh_port, retry_delays ))
            for sig in 'INT TERM'.split():
                loop.add_signal_handler(getattr(signal, f'SIG{sig}'), muxer.cancel)
            try: return loop.run_until_complete(muxer)
            except asyncio.CancelledError: return
# Run main() only when started as a script, with its return value as the exit code.
if __name__ == '__main__':
    sys.exit(main())