diff --git a/.gitignore b/.gitignore index 67d4bbb5..e7c48d8a 100644 --- a/.gitignore +++ b/.gitignore @@ -46,7 +46,6 @@ dbconfig.json *.egg *.egg-info venv/* -.venv/* node_modules/ # Test artifacts *_flag_if_job_task_fail diff --git a/ci/ssh/docker-compose.yml b/ci/ssh/docker-compose.yml index b0308f45..fed6a6c2 100644 --- a/ci/ssh/docker-compose.yml +++ b/ci/ssh/docker-compose.yml @@ -12,6 +12,17 @@ services: - "22" volumes: - ssh_config:/root/.ssh + jumphost: + image: takeyamajp/ubuntu-sshd:ubuntu22.04 + build: . + container_name: jumphost + hostname: jumphost + environment: + ROOT_PASSWORD: dpdispatcher + expose: + - "22" + volumes: + - ssh_config:/root/.ssh test: image: python:3.10 tty: true @@ -25,6 +36,7 @@ services: - ../..:/dpdispatcher depends_on: - server + - jumphost volumes: ssh_config: diff --git a/ci/ssh/start-ssh.sh b/ci/ssh/start-ssh.sh index 71688e10..0a63dc3e 100755 --- a/ci/ssh/start-ssh.sh +++ b/ci/ssh/start-ssh.sh @@ -2,5 +2,27 @@ docker compose up -d --no-build +# Set up SSH keys on server docker exec server /bin/bash -c "ssh-keygen -b 2048 -t rsa -f /root/.ssh/id_rsa -q -N \"\" && cat /root/.ssh/id_rsa.pub >> /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys" docker exec server /bin/bash -c "mkdir -p /dpdispatcher_working" +docker exec server /bin/bash -c "mkdir -p /tmp/rsync_test" + +# Set up SSH keys on jumphost and configure it to access server +docker exec jumphost /bin/bash -c "ssh-keygen -b 2048 -t rsa -f /root/.ssh/id_rsa -q -N \"\" && cat /root/.ssh/id_rsa.pub >> /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys" + +# Copy keys between containers to enable jump host functionality +# Get the public key from jumphost and add it to server's authorized_keys +docker exec jumphost /bin/bash -c "cat /root/.ssh/id_rsa.pub" | docker exec -i server /bin/bash -c "cat >> /root/.ssh/authorized_keys" + +# Get the public key from test (which shares volume with server) and add it to jumphost authorized_keys +docker exec test /bin/bash -c "cat /root/.ssh/id_rsa.pub" | docker exec -i jumphost /bin/bash -c "cat >> /root/.ssh/authorized_keys" + +# Configure SSH client settings for known hosts to avoid host key verification +docker exec test /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config" +docker exec jumphost /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config" +docker exec server /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config" + +# Install rsync on all containers +docker exec test /bin/bash -c "apt-get -y update && apt-get -y install rsync" +docker exec jumphost /bin/bash -c "apt-get -y update && apt-get -y install rsync" +docker exec server /bin/bash -c "apt-get -y update && apt-get -y install rsync" diff --git a/doc/context.md b/doc/context.md index 56413eb9..bb09a5ee 100644 --- a/doc/context.md +++ b/doc/context.md @@ -33,6 +33,28 @@ It's suggested to generate [SSH keys](https://help.ubuntu.com/community/SSH/Open Note that `SSH` context is [non-login](https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html), so `bash_profile` files will not be executed outside the submission script. +### SSH Jump Host (Bastion Server) + +For connecting to internal servers through a jump host (bastion server), SSH context supports jump host configuration. This allows connecting to internal servers that are not directly accessible from the internet. + +Specify the ProxyCommand directly using {dargs:argument}`proxy_command `: + +```json +{ + "context_type": "SSHContext", + "remote_profile": { + "hostname": "internal-server.company.com", + "username": "user", + "key_filename": "/path/to/internal_key", + "proxy_command": "ssh -W %h:%p -i /path/to/jump_key jumpuser@bastion.company.com" + } +} +``` + +The proxy command uses OpenSSH ProxyCommand syntax. `%h` and `%p` are replaced with the target hostname and port. + +This configuration establishes the connection path: Local → Jump Host → Target Server. + ## Bohrium {dargs:argument}`context_type `: `Bohrium` diff --git a/dpdispatcher/contexts/ssh_context.py b/dpdispatcher/contexts/ssh_context.py index 2f662f46..66440928 100644 --- a/dpdispatcher/contexts/ssh_context.py +++ b/dpdispatcher/contexts/ssh_context.py @@ -45,6 +45,7 @@ def __init__( tar_compress=True, look_for_keys=True, execute_command=None, + proxy_command=None, ): self.hostname = hostname self.username = username @@ -58,6 +59,7 @@ def __init__( self.tar_compress = tar_compress self.look_for_keys = look_for_keys self.execute_command = execute_command + self.proxy_command = proxy_command self._keyboard_interactive_auth = False self._setup_ssh() @@ -142,7 +144,12 @@ def _setup_ssh(self): # transport = self.ssh.get_transport() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(self.timeout) - sock.connect((self.hostname, self.port)) + + # Use ProxyCommand if configured (either directly or via jump host parameters) + if self.proxy_command is not None: + sock = paramiko.ProxyCommand(self.proxy_command) + else: + sock.connect((self.hostname, self.port)) # Make a Paramiko Transport object using the socket ts = paramiko.Transport(sock) @@ -340,6 +347,9 @@ def arginfo(): "enable searching for discoverable private key files in ~/.ssh/" ) doc_execute_command = "execute command after ssh connection is established." + doc_proxy_command = ( + "ProxyCommand to use for SSH connection through intermediate servers." + ) ssh_remote_profile_args = [ Argument("hostname", str, optional=False, doc=doc_hostname), Argument("username", str, optional=False, doc=doc_username), @@ -388,6 +398,13 @@ def arginfo(): default=None, doc=doc_execute_command, ), + Argument( + "proxy_command", + [str, type(None)], + optional=True, + default=None, + doc=doc_proxy_command, + ), ] ssh_remote_profile_format = Argument( "ssh_session", dict, ssh_remote_profile_args @@ -396,23 +413,37 @@ def arginfo(): def put(self, from_f, to_f): if self.rsync_available: + # For rsync, we need to use %h:%p placeholders for target host/port + proxy_cmd_rsync = None + if self.proxy_command is not None: + proxy_cmd_rsync = self.proxy_command.replace( + f"{self.hostname}:{self.port}", "%h:%p" + ) return rsync( from_f, self.remote + ":" + to_f, port=self.port, key_filename=self.key_filename, timeout=self.timeout, + proxy_command=proxy_cmd_rsync, ) return self.sftp.put(from_f, to_f) def get(self, from_f, to_f): if self.rsync_available: + # For rsync, we need to use %h:%p placeholders for target host/port + proxy_cmd_rsync = None + if self.proxy_command is not None: + proxy_cmd_rsync = self.proxy_command.replace( + f"{self.hostname}:{self.port}", "%h:%p" + ) return rsync( self.remote + ":" + from_f, to_f, port=self.port, key_filename=self.key_filename, timeout=self.timeout, + proxy_command=proxy_cmd_rsync, ) return self.sftp.get(from_f, to_f) diff --git a/dpdispatcher/utils/utils.py b/dpdispatcher/utils/utils.py index 33f586cd..4a75fe5d 100644 --- a/dpdispatcher/utils/utils.py +++ b/dpdispatcher/utils/utils.py @@ -2,6 +2,7 @@ import hashlib import hmac import os +import shlex import struct import subprocess import time @@ -89,6 +90,7 @@ def rsync( port: int = 22, key_filename: Optional[str] = None, timeout: Union[int, float] = 10, + proxy_command: Optional[str] = None, ): """Call rsync to transfer files. @@ -104,6 +106,8 @@ def rsync( identity file name timeout : int, default=10 timeout for ssh + proxy_command : str, optional + ProxyCommand to use for SSH connection Raises ------ @@ -124,20 +128,30 @@ def rsync( ] if key_filename is not None: ssh_cmd.extend(["-i", key_filename]) + + # Use proxy_command if provided + if proxy_command is not None: + ssh_cmd.extend(["-o", f"ProxyCommand={proxy_command}"]) + + # Properly escape the SSH command for rsync's -e option + ssh_cmd_str = " ".join(shlex.quote(part) for part in ssh_cmd) + cmd = [ "rsync", # -a: archieve # -z: compress "-az", "-e", - " ".join(ssh_cmd), + ssh_cmd_str, "-q", from_file, to_file, ] - ret, out, err = run_cmd_with_all_output(cmd, shell=False) + # Convert to string for shell=True + cmd_str = " ".join(shlex.quote(arg) for arg in cmd) + ret, out, err = run_cmd_with_all_output(cmd_str, shell=True) if ret != 0: - raise RuntimeError(f"Failed to run {cmd}: {err}") + raise RuntimeError(f"Failed to run {cmd_str}: {err}") class RetrySignal(Exception): diff --git a/examples/machine/ssh_proxy_command.json b/examples/machine/ssh_proxy_command.json new file mode 100644 index 00000000..abca964b --- /dev/null +++ b/examples/machine/ssh_proxy_command.json @@ -0,0 +1,13 @@ +{ + "batch_type": "Shell", + "context_type": "SSHContext", + "local_root": "./", + "remote_root": "/home/user/work", + "remote_profile": { + "hostname": "internal-server.company.com", + "username": "user", + "port": 22, + "key_filename": "~/.ssh/id_rsa", + "proxy_command": "ssh -W %h:%p -i ~/.ssh/jump_key jumpuser@bastion.company.com" + } +} diff --git a/tests/test_examples.py b/tests/test_examples.py index c5504925..640e8998 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -24,6 +24,7 @@ (machine_args, p_examples / "machine" / "expanse.json"), (machine_args, p_examples / "machine" / "lazy_local.json"), (machine_args, p_examples / "machine" / "mandu.json"), + (machine_args, p_examples / "machine" / "ssh_proxy_command.json"), (resources_args, p_examples / "resources" / "expanse_cpu.json"), (resources_args, p_examples / "resources" / "mandu.json"), (resources_args, p_examples / "resources" / "tiger.json"), diff --git a/tests/test_rsync_proxy.py b/tests/test_rsync_proxy.py new file mode 100644 index 00000000..9ac1d9be --- /dev/null +++ b/tests/test_rsync_proxy.py @@ -0,0 +1,130 @@ +import os +import shutil +import sys +import tempfile +import unittest + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +__package__ = "tests" + +from dpdispatcher.utils.utils import rsync + + +@unittest.skipIf( + os.environ.get("DPDISPATCHER_TEST") != "ssh", "outside the ssh testing environment" +) +class TestRsyncProxyCommand(unittest.TestCase): + """Test rsync function with proxy command support.""" + + def setUp(self): + """Set up test files for rsync operations.""" + # Check if rsync is available before running tests + if shutil.which("rsync") is None: + self.skipTest("rsync not available") + + # Create temporary test files + self.test_content = "test content for rsync" + + # Local test file + self.local_file = tempfile.NamedTemporaryFile(mode="w", delete=False) + self.local_file.write(self.test_content) + self.local_file.close() + + # Remote paths for testing + self.remote_test_dir = "/tmp/rsync_test" + self.remote_file_direct = f"root@server:{self.remote_test_dir}/test_direct.txt" + self.remote_file_proxy = f"root@server:{self.remote_test_dir}/test_proxy.txt" + + def tearDown(self): + """Clean up test files.""" + # Remove local test file + os.unlink(self.local_file.name) + + def test_rsync_with_proxy_command(self): + """Test rsync with proxy command via jump host.""" + # Test rsync through jump host: test -> jumphost -> server + rsync( + self.local_file.name, + self.remote_file_proxy, + key_filename="/root/.ssh/id_rsa", + proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + # Verify the file was transferred by reading it back + with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file: + download_path = download_file.name + + rsync( + self.remote_file_proxy, + download_path, + key_filename="/root/.ssh/id_rsa", + proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + # Verify content matches + with open(download_path) as f: + content = f.read() + self.assertEqual(content, self.test_content) + + # Clean up + os.unlink(download_path) + + def test_rsync_direct_connection(self): + """Test rsync without proxy command (direct connection).""" + # Test direct rsync: test -> server + rsync( + self.local_file.name, + self.remote_file_direct, + key_filename="/root/.ssh/id_rsa", + ) + + # Verify the file was transferred by reading it back + with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file: + download_path = download_file.name + + rsync(self.remote_file_direct, download_path, key_filename="/root/.ssh/id_rsa") + + # Verify content matches + with open(download_path) as f: + content = f.read() + self.assertEqual(content, self.test_content) + + # Clean up + os.unlink(download_path) + + def test_rsync_with_additional_options(self): + """Test rsync with proxy command and additional SSH options.""" + # Test rsync with custom port, timeout, and proxy + rsync( + self.local_file.name, + self.remote_file_proxy, + port=22, + key_filename="/root/.ssh/id_rsa", + timeout=30, + proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + # Verify the file exists on remote by attempting to download + with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file: + download_path = download_file.name + + rsync( + self.remote_file_proxy, + download_path, + port=22, + key_filename="/root/.ssh/id_rsa", + timeout=30, + proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + # Verify content + with open(download_path) as f: + content = f.read() + self.assertEqual(content, self.test_content) + + # Clean up + os.unlink(download_path) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_ssh_jump_host.py b/tests/test_ssh_jump_host.py new file mode 100644 index 00000000..45225ac4 --- /dev/null +++ b/tests/test_ssh_jump_host.py @@ -0,0 +1,90 @@ +import os +import sys +import unittest + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +__package__ = "tests" +from .context import ( + SSHSession, + setUpModule, # noqa: F401 +) + + +@unittest.skipIf( + os.environ.get("DPDISPATCHER_TEST") != "ssh", "outside the ssh testing environment" +) +class TestSSHJumpHost(unittest.TestCase): + """Test SSH jump host functionality.""" + + def test_proxy_command_connection(self): + """Test SSH connection using proxy_command via jump host.""" + # Test connection from test -> server via jumphost + ssh_session = SSHSession( + hostname="server", + username="root", + key_filename="/root/.ssh/id_rsa", + proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + # Verify the connection was established + self.assertIsNotNone(ssh_session.ssh) + self.assertTrue(ssh_session._check_alive()) + + # Test running a simple command through the proxy + assert ssh_session.ssh is not None # for type checker + stdin, stdout, stderr = ssh_session.ssh.exec_command("echo 'test via proxy'") + output = stdout.read().decode().strip() + self.assertEqual(output, "test via proxy") + + # Verify proxy_command attribute is set correctly + self.assertEqual( + ssh_session.proxy_command, + "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost", + ) + + ssh_session.close() + + def test_direct_connection_no_proxy(self): + """Test direct SSH connection without proxy command.""" + # Test direct connection from test -> server (no proxy) + ssh_session = SSHSession( + hostname="server", username="root", key_filename="/root/.ssh/id_rsa" + ) + + # Verify the connection was established + self.assertIsNotNone(ssh_session.ssh) + self.assertTrue(ssh_session._check_alive()) + + # Test running a simple command + assert ssh_session.ssh is not None # for type checker + stdin, stdout, stderr = ssh_session.ssh.exec_command("echo 'test direct'") + output = stdout.read().decode().strip() + self.assertEqual(output, "test direct") + + # Verify no proxy_command is set + self.assertIsNone(ssh_session.proxy_command) + + ssh_session.close() + + def test_jump_host_direct_connection(self): + """Test direct connection to jump host itself.""" + # Test direct connection from test -> jumphost + ssh_session = SSHSession( + hostname="jumphost", username="root", key_filename="/root/.ssh/id_rsa" + ) + + # Verify the connection was established + self.assertIsNotNone(ssh_session.ssh) + self.assertTrue(ssh_session._check_alive()) + + # Test running a command on jumphost + assert ssh_session.ssh is not None # for type checker + stdin, stdout, stderr = ssh_session.ssh.exec_command("hostname") + output = stdout.read().decode().strip() + self.assertEqual(output, "jumphost") + + ssh_session.close() + + +if __name__ == "__main__": + unittest.main()