Alternative SSH conn that binds the remote docker socket locally
Signed-off-by: aiordache <anca.iordache@docker.com>
aiordache committed Oct 23, 2020
1 parent 1cb8896 commit 787b4fd
Showing 4 changed files with 119 additions and 110 deletions.
15 changes: 13 additions & 2 deletions docker/api/client.py
@@ -43,6 +43,11 @@
 except ImportError:
     pass
 
+try:
+    from ..transport import SSHClientAdapter
+except ImportError:
+    pass
+
 
 class APIClient(
         requests.Session,
@@ -161,11 +166,17 @@ def __init__(self, base_url=None, version=None,
             )
             self.mount('http+docker://', self._custom_adapter)
             self.base_url = 'http+docker://localnpipe'
+        elif base_url.startswith('ssh://') and use_ssh_client:
+            self._custom_adapter = SSHClientAdapter(
+                base_url, timeout, pool_connections=num_pools
+            )
+            self.mount('http+docker://', self._custom_adapter)
+            self._unmount('http://', 'https://')
+            self.base_url = 'http+docker://localhost'
         elif base_url.startswith('ssh://'):
             try:
                 self._custom_adapter = SSHHTTPAdapter(
-                    base_url, timeout, pool_connections=num_pools,
-                    shell_out=use_ssh_client
+                    base_url, timeout, pool_connections=num_pools
                 )
             except NameError:
                 raise DockerException(
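With this change, constructing an APIClient with an ssh:// base URL and use_ssh_client=True selects the new tunnel-based SSHClientAdapter, while use_ssh_client=False keeps the paramiko-based SSHHTTPAdapter. A minimal usage sketch, assuming a reachable SSH host (user@remote-host is a placeholder) and an ssh binary on PATH:

import docker

# use_ssh_client=True picks the SSHClientAdapter added by this commit;
# use_ssh_client=False falls back to the paramiko-based SSHHTTPAdapter.
client = docker.APIClient(
    base_url='ssh://user@remote-host',
    use_ssh_client=True,
)
print(client.version())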
5 changes: 5 additions & 0 deletions docker/transport/__init__.py
@@ -11,3 +11,8 @@
     from .sshconn import SSHHTTPAdapter
 except ImportError:
     pass
+
+try:
+    from .sshtunnel import SSHClientAdapter
+except ImportError:
+    pass
126 changes: 18 additions & 108 deletions docker/transport/sshconn.py
@@ -3,8 +3,6 @@
 import six
 import logging
 import os
-import socket
-import subprocess
 
 from docker.transport.basehttpadapter import BaseHTTPAdapter
 from .. import constants
@@ -22,104 +20,33 @@
 RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
 
 
-class SSHSocket(socket.socket):
-    def __init__(self, host):
-        super(SSHSocket, self).__init__(
-            socket.AF_INET, socket.SOCK_STREAM)
-        self.host = host
-        self.port = None
-        if ':' in host:
-            self.host, self.port = host.split(':')
-        self.proc = None
-
-    def connect(self, **kwargs):
-        port = '' if not self.port else '-p {}'.format(self.port)
-        args = [
-            'ssh',
-            '-q',
-            self.host,
-            port,
-            'docker system dial-stdio'
-        ]
-        self.proc = subprocess.Popen(
-            ' '.join(args),
-            shell=True,
-            stdout=subprocess.PIPE,
-            stdin=subprocess.PIPE)
-
-    def _write(self, data):
-        if not self.proc or self.proc.stdin.closed:
-            raise Exception('SSH subprocess not initiated.'
-                            'connect() must be called first.')
-        written = self.proc.stdin.write(data)
-        self.proc.stdin.flush()
-        return written
-
-    def sendall(self, data):
-        self._write(data)
-
-    def send(self, data):
-        return self._write(data)
-
-    def recv(self):
-        if not self.proc:
-            raise Exception('SSH subprocess not initiated.'
-                            'connect() must be called first.')
-        return self.proc.stdout.read()
-
-    def makefile(self, mode):
-        if not self.proc or self.proc.stdout.closed:
-            self.connect()
-        return self.proc.stdout
-
-    def close(self):
-        if not self.proc or self.proc.stdin.closed:
-            return
-        self.proc.stdin.write(b'\n\n')
-        self.proc.stdin.flush()
-        self.proc.terminate()
-
-
 class SSHConnection(httplib.HTTPConnection, object):
-    def __init__(self, ssh_transport=None, timeout=60, host=None):
+    def __init__(self, ssh_transport, timeout=60):
         super(SSHConnection, self).__init__(
             'localhost', timeout=timeout
         )
         self.ssh_transport = ssh_transport
         self.timeout = timeout
-        self.ssh_host = host
 
     def connect(self):
-        if self.ssh_transport:
-            sock = self.ssh_transport.open_session()
-            sock.settimeout(self.timeout)
-            sock.exec_command('docker system dial-stdio')
-        else:
-            sock = SSHSocket(self.ssh_host)
-            sock.settimeout(self.timeout)
-            sock.connect()
-
+        sock = self.ssh_transport.open_session()
+        sock.settimeout(self.timeout)
+        sock.exec_command('docker system dial-stdio')
         self.sock = sock
 
 
 class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
     scheme = 'ssh'
 
-    def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
+    def __init__(self, ssh_client, timeout=60, maxsize=10):
         super(SSHConnectionPool, self).__init__(
             'localhost', timeout=timeout, maxsize=maxsize
         )
-        self.ssh_transport = None
+        self.ssh_transport = ssh_client.get_transport()
         self.timeout = timeout
-        if ssh_client:
-            self.ssh_transport = ssh_client.get_transport()
-        self.ssh_host = host
-        self.ssh_port = None
-        if ':' in host:
-            self.ssh_host, self.ssh_port = host.split(':')
 
     def _new_conn(self):
-        return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host)
+        return SSHConnection(self.ssh_transport, self.timeout)
 
     # When re-using connections, urllib3 calls fileno() on our
     # SSH channel instance, quickly overloading our fd limit. To avoid this,
@@ -151,21 +78,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter):
     ]
 
     def __init__(self, base_url, timeout=60,
-                 pool_connections=constants.DEFAULT_NUM_POOLS,
-                 shell_out=True):
-        self.ssh_client = None
-        if not shell_out:
-            self._create_paramiko_client(base_url)
-            self._connect()
-
-        self.ssh_host = base_url.lstrip('ssh://')
-        self.timeout = timeout
-        self.pools = RecentlyUsedContainer(
-            pool_connections, dispose_func=lambda p: p.close()
-        )
-        super(SSHHTTPAdapter, self).__init__()
-
-    def _create_paramiko_client(self, base_url):
+                 pool_connections=constants.DEFAULT_NUM_POOLS):
         logging.getLogger("paramiko").setLevel(logging.WARNING)
         self.ssh_client = paramiko.SSHClient()
         base_url = six.moves.urllib_parse.urlparse(base_url)
@@ -195,18 +108,18 @@ def _create_paramiko_client(self, base_url):
         self.ssh_client.load_system_host_keys()
         self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())
 
+        self._connect()
+        self.timeout = timeout
+        self.pools = RecentlyUsedContainer(
+            pool_connections, dispose_func=lambda p: p.close()
+        )
+        super(SSHHTTPAdapter, self).__init__()
+
     def _connect(self):
-        if self.ssh_client:
-            self.ssh_client.connect(**self.ssh_params)
+        self.ssh_client.connect(**self.ssh_params)
 
     def get_connection(self, url, proxies=None):
         with self.pools.lock:
-            if not self.ssh_client:
-                return SSHConnectionPool(
-                    ssh_client=self.ssh_client,
-                    timeout=self.timeout,
-                    host=self.ssh_host
-                )
             pool = self.pools.get(url)
             if pool:
                 return pool
@@ -216,15 +129,12 @@ def get_connection(self, url, proxies=None):
             self._connect()
 
             pool = SSHConnectionPool(
-                ssh_client=self.ssh_client,
-                timeout=self.timeout,
-                host=self.ssh_host
+                self.ssh_client, self.timeout
             )
             self.pools[url] = pool
 
         return pool
 
     def close(self):
         super(SSHHTTPAdapter, self).close()
-        if self.ssh_client:
-            self.ssh_client.close()
+        self.ssh_client.close()
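The remaining paramiko path still works the same way at the protocol level: each connection opens an SSH channel and runs docker system dial-stdio on the remote host, which turns the channel into a raw byte stream to the engine. The difference after this commit is that the paramiko client is always created and connected in __init__, since the shell-out branch moved to sshtunnel.py. A standalone sketch of that mechanism, assuming paramiko is installed and remote-host / user are placeholders:

import paramiko

# Mirror of what SSHConnection.connect() does for every pooled connection.
ssh_client = paramiko.SSHClient()
ssh_client.load_system_host_keys()
ssh_client.connect('remote-host', username='user')

channel = ssh_client.get_transport().open_session()
channel.settimeout(60)
channel.exec_command('docker system dial-stdio')
# The channel now behaves like a socket to the remote daemon; this is what
# SSHConnectionPool hands to urllib3.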
83 changes: 83 additions & 0 deletions docker/transport/sshtunnel.py
@@ -0,0 +1,83 @@
+import os
+import signal
+import six
+import subprocess
+import tempfile
+import time
+
+try:
+    import requests.packages.urllib3 as urllib3
+except ImportError:
+    import urllib3
+
+from .. import constants
+
+from docker.transport.basehttpadapter import BaseHTTPAdapter
+from .unixconn import UnixHTTPConnectionPool
+
+
+RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
+
+
+class SSHClientAdapter(BaseHTTPAdapter):
+    def __init__(self, socket_url, timeout=60,
+                 pool_connections=constants.DEFAULT_NUM_POOLS):
+        self.ssh_host = socket_url.lstrip('ssh://')
+        self.ssh_port = None
+        if ':' in self.ssh_host:
+            self.ssh_host, self.ssh_port = self.ssh_host.split(':')
+        self.timeout = timeout
+        self.socket_path = None
+
+        self.__create_ssh_tunnel()
+
+        self.pools = RecentlyUsedContainer(
+            pool_connections, dispose_func=lambda p: p.close()
+        )
+        super(SSHClientAdapter, self).__init__()
+
+    def __create_ssh_tunnel(self):
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.socket_path = os.path.join(self.temp_dir.name, "docker.sock")
+
+        port = '' if not self.ssh_port else '-p {}'.format(self.ssh_port)
+        # bind remote engine socket locally to a temporary file
+        args = [
+            'ssh',
+            '-NL',
+            '{}:/var/run/docker.sock'.format(self.socket_path),
+            self.ssh_host,
+            port
+        ]
+        self.proc = subprocess.Popen(
+            ' '.join(args),
+            shell=True, preexec_fn=lambda:signal.signal(signal.SIGINT, signal.SIG_IGN))
+        count = .0
+
+        while not os.path.exists(self.socket_path):
+            time.sleep(.1)
+            count = count + 0.1
+            if count > self.timeout:
+                raise Exception("Failed to connect via SSH")
+
+    def get_connection(self, url, proxies=None):
+        with self.pools.lock:
+            pool = self.pools.get(url)
+            if pool:
+                return pool
+
+            pool = UnixHTTPConnectionPool(
+                url, self.socket_path, self.timeout
+            )
+            self.pools[url] = pool
+
+        return pool
+
+    def request_url(self, request, proxies):
+        return request.path_url
+
+    def close(self):
+        super(SSHClientAdapter, self).close()
+        if self.proc:
+            self.proc.terminate()
+            self.proc.wait()

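The new adapter is roughly a library-level equivalent of forwarding the remote engine socket by hand with ssh -NL and then talking to the forwarded unix socket. A manual sketch of that flow, assuming user@remote-host is a placeholder and the remote daemon listens on /var/run/docker.sock:

import os
import subprocess
import tempfile
import time

import docker

# Forward the remote socket to a local temporary path, then use it as a
# plain unix socket -- the same steps SSHClientAdapter automates.
sock = os.path.join(tempfile.mkdtemp(), 'docker.sock')
tunnel = subprocess.Popen(
    ['ssh', '-NL', '{}:/var/run/docker.sock'.format(sock), 'user@remote-host']
)
while not os.path.exists(sock):
    time.sleep(0.1)

client = docker.APIClient(base_url='unix://{}'.format(sock))
print(client.version())

client.close()
tunnel.terminate()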