From ac2ab9905df99cd602493018d8f435abba0635ee Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Fri, 19 Mar 2021 21:03:46 -0500 Subject: [PATCH 1/2] sshdriver: Add port forwarding Adds a context manager function that will forward connections to a port on the local host to a port on the target, and target ports to a port on the local host Signed-off-by: Joshua Watt --- labgrid/driver/sshdriver.py | 72 +++++++++++++++++++++++++++++++++++++ tests/test_sshdriver.py | 41 +++++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/labgrid/driver/sshdriver.py b/labgrid/driver/sshdriver.py index 0dd31103e..cbf789097 100644 --- a/labgrid/driver/sshdriver.py +++ b/labgrid/driver/sshdriver.py @@ -1,5 +1,6 @@ # pylint: disable=no-member """The SSHDriver uses SSH as a transport to implement CommandProtocol and FileTransferProtocol""" +import contextlib import logging import os import shutil @@ -15,6 +16,7 @@ from .common import Driver from ..step import step from .exception import ExecutionError +from ..util.helper import get_free_port from ..util.proxy import proxymanager from ..util.timeout import Timeout @@ -222,6 +224,76 @@ def interact(self, cmd=None): ) return sub.wait() + @contextlib.contextmanager + def _forward(self, forward): + cmd = ["ssh", *self.ssh_prefix, + "-O", "forward", forward, + self.networkservice.address + ] + self.logger.debug("Running command: %s", cmd) + subprocess.run(cmd, check=True) + try: + yield + finally: + cmd = ["ssh", *self.ssh_prefix, + "-O", "cancel", forward, + self.networkservice.address + ] + self.logger.debug("Running command: %s", cmd) + # Master socket may have been cleaned up already, so don't bother + # the user with an error message + subprocess.run(cmd, stderr=subprocess.DEVNULL) + + @Driver.check_active + @contextlib.contextmanager + def forward_local_port(self, remoteport, localport=None): + """Forward a local port to a remote port on the target + + A context manager that keeps a local port forwarded to a remote port as + long as the context remains valid. A connection can be made to the + returned port on localhost and it will be forwarded to the remote port + on the target device + + usage: + with ssh.forward_local_port(8080) as localport: + # Use localhost:localport here to connect to port 8080 on the + # target + + returns: + localport + """ + if not self._check_keepalive(): + raise ExecutionError("Keepalive no longer running") + + if localport is None: + localport = get_free_port() + + forward = "-L%d:localhost:%d" % (localport, remoteport) + with self._forward(forward): + yield localport + + @Driver.check_active + @contextlib.contextmanager + def forward_remote_port(self, remoteport, localport): + """Forward a remote port on the target to a local port + + A context manager that keeps a remote port forwarded to a local port as + long as the context remains valid. A connection can be made to the + remote on the target device will be forwarded to the returned local + port on localhost + + usage: + with ssh.forward_remote_port(8080, 8081) as localport: + # Connections to port 8080 on the target will be redirected to + # localhost:8081 + """ + if not self._check_keepalive(): + raise ExecutionError("Keepalive no longer running") + + forward = "-R%d:localhost:%d" % (remoteport, localport) + with self._forward(forward): + yield + @Driver.check_active @step(args=['src', 'dst']) def scp(self, *, src, dst): diff --git a/tests/test_sshdriver.py b/tests/test_sshdriver.py index 59dca6d2c..875570822 100644 --- a/tests/test_sshdriver.py +++ b/tests/test_sshdriver.py @@ -1,8 +1,10 @@ import pytest +import socket from labgrid.driver import SSHDriver, ExecutionError from labgrid.exceptions import NoResourceFoundError from labgrid.resource import NetworkService +from labgrid.util.helper import get_free_port @pytest.fixture(scope='function') def ssh_driver_mocked_and_activated(target, mocker): @@ -131,3 +133,42 @@ def test_local_run_check(ssh_localhost, tmpdir): res = ssh_localhost.run_check("echo Hello") assert res == (["Hello"]) + +@pytest.mark.sshusername +def test_local_port_forward(ssh_localhost, tmpdir): + remoteport = get_free_port() + test_string = "Hello World" + + with ssh_localhost.forward_local_port(remoteport) as localport: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket: + server_socket.bind(("127.0.0.1", remoteport)) + server_socket.listen(1) + + send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + send_socket.connect(("127.0.0.1", localport)) + + client_socket, address = server_socket.accept() + send_socket.send(test_string.encode('utf-8')) + + assert client_socket.recv(16).decode("utf-8") == test_string + +@pytest.mark.sshusername +def test_local_remote_forward(ssh_localhost, tmpdir): + remoteport = get_free_port() + localport = get_free_port() + test_string = "Hello World" + + with ssh_localhost.forward_remote_port(remoteport, localport): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket: + server_socket.bind(("127.0.0.1", localport)) + server_socket.listen(1) + + send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + send_socket.connect(("127.0.0.1", remoteport)) + + client_socket, address = server_socket.accept() + send_socket.send(test_string.encode('utf-8')) + + assert client_socket.recv(16).decode("utf-8") == test_string From 51f3fe5db75d23fe51e405231e9fcc88a13f9a22 Mon Sep 17 00:00:00 2001 From: Joshua Watt Date: Mon, 22 Mar 2021 11:14:12 -0500 Subject: [PATCH 2/2] client: Add port forwarding Adds a subcommand to forward ports between localhost and the target Signed-off-by: Joshua Watt --- labgrid/remote/client.py | 71 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/labgrid/remote/client.py b/labgrid/remote/client.py index ae8e7ce29..7d857ffe3 100755 --- a/labgrid/remote/client.py +++ b/labgrid/remote/client.py @@ -2,10 +2,12 @@ coordinator, acquire a place and interact with the connected resources""" import argparse import asyncio +import contextlib import os import subprocess import traceback import logging +import signal import sys from textwrap import indent from socket import gethostname @@ -1087,6 +1089,29 @@ def sshfs(self): res = drv.sshfs(path=self.args.path, mountpoint=self.args.mountpoint) + def forward(self): + if not self.args.local and not self.args.remote: + print("Nothing to forward") + return + + drv = self._get_ssh() + + with contextlib.ExitStack() as stack: + for local, remote in self.args.local: + localport = stack.enter_context(drv.forward_local_port(remote, localport=local)) + print("Forwarding local port %d to remote port %d" % (localport, remote)) + + for local, remote in self.args.remote: + stack.enter_context(drv.forward_remote_port(remote, localport)) + print("Forwarding remote port %d to local port %d" % (remote, local)) + + try: + print("Waiting for CTRL+C...") + while True: + signal.pause() + except KeyboardInterrupt: + print("Exiting...") + def telnet(self): place = self.get_acquired_place() ip = self._get_ip(place) @@ -1328,6 +1353,40 @@ def find_any_role_with_place(config): return (role, place) return None, None +class LocalPort(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super().__init__(option_strings, dest, nargs=None, default=[], **kwargs) + + def __call__(self, parser, namespace, value, option_string): + if ":" in value: + (local, remote) = value.split(":") + local = int(local) + remote = int(remote) + else: + local = None + remote = int(value) + + v = getattr(namespace, self.dest, []) + v.append((local, remote)) + setattr(namespace, self.dest, v) + +class RemotePort(argparse.Action): + def __init__(self, option_strings, dest, nargs=None, **kwargs): + if nargs is not None: + raise ValueError("nargs not allowed") + super().__init__(option_strings, dest, nargs=None, default=[], **kwargs) + + def __call__(self, parser, namespace, value, option_string): + (remote, local) = value.split(":") + remote = int(remote) + local = int(local) + + v = getattr(namespace, self.dest, []) + v.append((local, remote)) + setattr(namespace, self.dest, v) + def main(): processwrapper.enable_logging() logging.basicConfig( @@ -1568,6 +1627,16 @@ def main(): subparser.add_argument('mountpoint', help='local path') subparser.set_defaults(func=ClientSession.sshfs) + subparser = subparsers.add_parser('forward', + help="forward local port to remote target") + subparser.add_argument("--local", "-L", metavar="[LOCAL:]REMOTE", + action=LocalPort, + help="Forward local port LOCAL to remote port REMOTE. If LOCAL is unspecified, an arbitrary port will be chosen") + subparser.add_argument("--remote", "-R", metavar="REMOTE:LOCAL", + action=RemotePort, + help="Forward remote port REMOTE to local port LOCAL") + subparser.set_defaults(func=ClientSession.forward) + subparser = subparsers.add_parser('telnet', help="connect via telnet") subparser.set_defaults(func=ClientSession.telnet) @@ -1647,7 +1716,7 @@ def main(): # make any leftover arguments available for some commands args, leftover = parser.parse_known_args() - if args.command not in ['ssh', 'rsync']: + if args.command not in ['ssh', 'rsync', 'forward']: args = parser.parse_args() else: args.leftover = leftover