Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ dbconfig.json
*.egg
*.egg-info
venv/*
.venv/*
node_modules/
# Test artifacts
*_flag_if_job_task_fail
Expand Down
12 changes: 12 additions & 0 deletions ci/ssh/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@ services:
- "22"
volumes:
- ssh_config:/root/.ssh
jumphost:
image: takeyamajp/ubuntu-sshd:ubuntu22.04
build: .
container_name: jumphost
hostname: jumphost
environment:
ROOT_PASSWORD: dpdispatcher
expose:
- "22"
volumes:
- ssh_config:/root/.ssh
test:
image: python:3.10
tty: true
Expand All @@ -25,6 +36,7 @@ services:
- ../..:/dpdispatcher
depends_on:
- server
- jumphost

volumes:
ssh_config:
22 changes: 22 additions & 0 deletions ci/ssh/start-ssh.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,27 @@

docker compose up -d --no-build

# Set up SSH keys on server
docker exec server /bin/bash -c "ssh-keygen -b 2048 -t rsa -f /root/.ssh/id_rsa -q -N \"\" && cat /root/.ssh/id_rsa.pub >> /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys"
docker exec server /bin/bash -c "mkdir -p /dpdispatcher_working"
docker exec server /bin/bash -c "mkdir -p /tmp/rsync_test"

# Set up SSH keys on jumphost and configure it to access server
docker exec jumphost /bin/bash -c "ssh-keygen -b 2048 -t rsa -f /root/.ssh/id_rsa -q -N \"\" && cat /root/.ssh/id_rsa.pub >> /root/.ssh/authorized_keys && chmod 600 /root/.ssh/authorized_keys"

# Copy keys between containers to enable jump host functionality
# Get the public key from jumphost and add it to server's authorized_keys
docker exec jumphost /bin/bash -c "cat /root/.ssh/id_rsa.pub" | docker exec -i server /bin/bash -c "cat >> /root/.ssh/authorized_keys"

# Get the public key from test (which shares volume with server) and add it to jumphost authorized_keys
docker exec test /bin/bash -c "cat /root/.ssh/id_rsa.pub" | docker exec -i jumphost /bin/bash -c "cat >> /root/.ssh/authorized_keys"

# Configure SSH client settings for known hosts to avoid host key verification
docker exec test /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config"
docker exec jumphost /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config"
docker exec server /bin/bash -c "echo 'StrictHostKeyChecking no' >> /root/.ssh/config && echo 'UserKnownHostsFile /dev/null' >> /root/.ssh/config"

# Install rsync on all containers
docker exec test /bin/bash -c "apt-get -y update && apt-get -y install rsync"
docker exec jumphost /bin/bash -c "apt-get -y update && apt-get -y install rsync"
docker exec server /bin/bash -c "apt-get -y update && apt-get -y install rsync"
22 changes: 22 additions & 0 deletions doc/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,28 @@ It's suggested to generate [SSH keys](https://help.ubuntu.com/community/SSH/Open

Note that `SSH` context is [non-login](https://www.gnu.org/software/bash/manual/html_node/Bash-Startup-Files.html), so `bash_profile` files will not be executed outside the submission script.

### SSH Jump Host (Bastion Server)

For connecting to internal servers through a jump host (bastion server), SSH context supports jump host configuration. This allows connecting to internal servers that are not directly accessible from the internet.

Specify the ProxyCommand directly using {dargs:argument}`proxy_command <machine[SSHContext]/remote_profile/proxy_command>`:

```json
{
"context_type": "SSHContext",
"remote_profile": {
"hostname": "internal-server.company.com",
"username": "user",
"key_filename": "/path/to/internal_key",
"proxy_command": "ssh -W %h:%p -i /path/to/jump_key jumpuser@bastion.company.com"
}
}
```

The proxy command uses OpenSSH ProxyCommand syntax. `%h` and `%p` are replaced with the target hostname and port.

This configuration establishes the connection path: Local → Jump Host → Target Server.

## Bohrium

{dargs:argument}`context_type <machine/context_type>`: `Bohrium`
Expand Down
33 changes: 32 additions & 1 deletion dpdispatcher/contexts/ssh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
tar_compress=True,
look_for_keys=True,
execute_command=None,
proxy_command=None,
):
self.hostname = hostname
self.username = username
Expand All @@ -58,6 +59,7 @@ def __init__(
self.tar_compress = tar_compress
self.look_for_keys = look_for_keys
self.execute_command = execute_command
self.proxy_command = proxy_command
self._keyboard_interactive_auth = False
self._setup_ssh()

Expand Down Expand Up @@ -142,7 +144,12 @@ def _setup_ssh(self):
# transport = self.ssh.get_transport()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(self.timeout)
sock.connect((self.hostname, self.port))

# Use ProxyCommand if configured (either directly or via jump host parameters)
if self.proxy_command is not None:
sock = paramiko.ProxyCommand(self.proxy_command)
else:
sock.connect((self.hostname, self.port))

# Make a Paramiko Transport object using the socket
ts = paramiko.Transport(sock)
Expand Down Expand Up @@ -340,6 +347,9 @@ def arginfo():
"enable searching for discoverable private key files in ~/.ssh/"
)
doc_execute_command = "execute command after ssh connection is established."
doc_proxy_command = (
"ProxyCommand to use for SSH connection through intermediate servers."
)
ssh_remote_profile_args = [
Argument("hostname", str, optional=False, doc=doc_hostname),
Argument("username", str, optional=False, doc=doc_username),
Expand Down Expand Up @@ -388,6 +398,13 @@ def arginfo():
default=None,
doc=doc_execute_command,
),
Argument(
"proxy_command",
[str, type(None)],
optional=True,
default=None,
doc=doc_proxy_command,
),
]
ssh_remote_profile_format = Argument(
"ssh_session", dict, ssh_remote_profile_args
Expand All @@ -396,23 +413,37 @@ def arginfo():

def put(self, from_f, to_f):
if self.rsync_available:
# For rsync, we need to use %h:%p placeholders for target host/port
proxy_cmd_rsync = None
if self.proxy_command is not None:
proxy_cmd_rsync = self.proxy_command.replace(
f"{self.hostname}:{self.port}", "%h:%p"
)
return rsync(
from_f,
self.remote + ":" + to_f,
port=self.port,
key_filename=self.key_filename,
timeout=self.timeout,
proxy_command=proxy_cmd_rsync,
)
return self.sftp.put(from_f, to_f)

def get(self, from_f, to_f):
if self.rsync_available:
# For rsync, we need to use %h:%p placeholders for target host/port
proxy_cmd_rsync = None
if self.proxy_command is not None:
proxy_cmd_rsync = self.proxy_command.replace(
f"{self.hostname}:{self.port}", "%h:%p"
)
return rsync(
self.remote + ":" + from_f,
to_f,
port=self.port,
key_filename=self.key_filename,
timeout=self.timeout,
proxy_command=proxy_cmd_rsync,
)
return self.sftp.get(from_f, to_f)

Expand Down
20 changes: 17 additions & 3 deletions dpdispatcher/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import hmac
import os
import shlex
import struct
import subprocess
import time
Expand Down Expand Up @@ -89,6 +90,7 @@ def rsync(
port: int = 22,
key_filename: Optional[str] = None,
timeout: Union[int, float] = 10,
proxy_command: Optional[str] = None,
):
"""Call rsync to transfer files.

Expand All @@ -104,6 +106,8 @@ def rsync(
identity file name
timeout : int, default=10
timeout for ssh
proxy_command : str, optional
ProxyCommand to use for SSH connection

Raises
------
Expand All @@ -124,20 +128,30 @@ def rsync(
]
if key_filename is not None:
ssh_cmd.extend(["-i", key_filename])

# Use proxy_command if provided
if proxy_command is not None:
ssh_cmd.extend(["-o", f"ProxyCommand={proxy_command}"])

# Properly escape the SSH command for rsync's -e option
ssh_cmd_str = " ".join(shlex.quote(part) for part in ssh_cmd)

cmd = [
"rsync",
# -a: archieve
# -z: compress
"-az",
"-e",
" ".join(ssh_cmd),
ssh_cmd_str,
"-q",
from_file,
to_file,
]
ret, out, err = run_cmd_with_all_output(cmd, shell=False)
# Convert to string for shell=True
cmd_str = " ".join(shlex.quote(arg) for arg in cmd)
ret, out, err = run_cmd_with_all_output(cmd_str, shell=True)
if ret != 0:
raise RuntimeError(f"Failed to run {cmd}: {err}")
raise RuntimeError(f"Failed to run {cmd_str}: {err}")


class RetrySignal(Exception):
Expand Down
13 changes: 13 additions & 0 deletions examples/machine/ssh_proxy_command.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"batch_type": "Shell",
"context_type": "SSHContext",
"local_root": "./",
"remote_root": "/home/user/work",
"remote_profile": {
"hostname": "internal-server.company.com",
"username": "user",
"port": 22,
"key_filename": "~/.ssh/id_rsa",
"proxy_command": "ssh -W %h:%p -i ~/.ssh/jump_key jumpuser@bastion.company.com"
}
}
1 change: 1 addition & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
(machine_args, p_examples / "machine" / "expanse.json"),
(machine_args, p_examples / "machine" / "lazy_local.json"),
(machine_args, p_examples / "machine" / "mandu.json"),
(machine_args, p_examples / "machine" / "ssh_proxy_command.json"),
(resources_args, p_examples / "resources" / "expanse_cpu.json"),
(resources_args, p_examples / "resources" / "mandu.json"),
(resources_args, p_examples / "resources" / "tiger.json"),
Expand Down
130 changes: 130 additions & 0 deletions tests/test_rsync_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os
import shutil
import sys
import tempfile
import unittest

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
__package__ = "tests"

from dpdispatcher.utils.utils import rsync


@unittest.skipIf(
os.environ.get("DPDISPATCHER_TEST") != "ssh", "outside the ssh testing environment"
)
class TestRsyncProxyCommand(unittest.TestCase):
"""Test rsync function with proxy command support."""

def setUp(self):
"""Set up test files for rsync operations."""
# Check if rsync is available before running tests
if shutil.which("rsync") is None:
self.skipTest("rsync not available")

# Create temporary test files
self.test_content = "test content for rsync"

# Local test file
self.local_file = tempfile.NamedTemporaryFile(mode="w", delete=False)
self.local_file.write(self.test_content)
self.local_file.close()

# Remote paths for testing
self.remote_test_dir = "/tmp/rsync_test"
self.remote_file_direct = f"root@server:{self.remote_test_dir}/test_direct.txt"
self.remote_file_proxy = f"root@server:{self.remote_test_dir}/test_proxy.txt"

def tearDown(self):
"""Clean up test files."""
# Remove local test file
os.unlink(self.local_file.name)

def test_rsync_with_proxy_command(self):
"""Test rsync with proxy command via jump host."""
# Test rsync through jump host: test -> jumphost -> server
rsync(
self.local_file.name,
self.remote_file_proxy,
key_filename="/root/.ssh/id_rsa",
proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost",
)

# Verify the file was transferred by reading it back
with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file:
download_path = download_file.name

rsync(
self.remote_file_proxy,
download_path,
key_filename="/root/.ssh/id_rsa",
proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost",
)

# Verify content matches
with open(download_path) as f:
content = f.read()
self.assertEqual(content, self.test_content)

# Clean up
os.unlink(download_path)

def test_rsync_direct_connection(self):
"""Test rsync without proxy command (direct connection)."""
# Test direct rsync: test -> server
rsync(
self.local_file.name,
self.remote_file_direct,
key_filename="/root/.ssh/id_rsa",
)

# Verify the file was transferred by reading it back
with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file:
download_path = download_file.name

rsync(self.remote_file_direct, download_path, key_filename="/root/.ssh/id_rsa")

# Verify content matches
with open(download_path) as f:
content = f.read()
self.assertEqual(content, self.test_content)

# Clean up
os.unlink(download_path)

def test_rsync_with_additional_options(self):
"""Test rsync with proxy command and additional SSH options."""
# Test rsync with custom port, timeout, and proxy
rsync(
self.local_file.name,
self.remote_file_proxy,
port=22,
key_filename="/root/.ssh/id_rsa",
timeout=30,
proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost",
)

# Verify the file exists on remote by attempting to download
with tempfile.NamedTemporaryFile(mode="w", delete=False) as download_file:
download_path = download_file.name

rsync(
self.remote_file_proxy,
download_path,
port=22,
key_filename="/root/.ssh/id_rsa",
timeout=30,
proxy_command="ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -i /root/.ssh/id_rsa -W server:22 root@jumphost",
)

# Verify content
with open(download_path) as f:
content = f.read()
self.assertEqual(content, self.test_content)

# Clean up
os.unlink(download_path)


if __name__ == "__main__":
unittest.main()
Loading
Loading