Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=no-member
"""The SSHDriver uses SSH as a transport to implement CommandProtocol and FileTransferProtocol"""
import contextlib
import logging
import os
import shutil
Expand All @@ -15,6 +16,7 @@
from .common import Driver
from ..step import step
from .exception import ExecutionError
from ..util.helper import get_free_port
from ..util.proxy import proxymanager
from ..util.timeout import Timeout

Expand Down Expand Up @@ -222,6 +224,76 @@ def interact(self, cmd=None):
)
return sub.wait()

@contextlib.contextmanager
def _forward(self, forward):
cmd = ["ssh", *self.ssh_prefix,
"-O", "forward", forward,
self.networkservice.address
]
self.logger.debug("Running command: %s", cmd)
subprocess.run(cmd, check=True)
try:
yield
finally:
cmd = ["ssh", *self.ssh_prefix,
"-O", "cancel", forward,
self.networkservice.address
]
self.logger.debug("Running command: %s", cmd)
# Master socket may have been cleaned up already, so don't bother
# the user with an error message
subprocess.run(cmd, stderr=subprocess.DEVNULL)

@Driver.check_active
@contextlib.contextmanager
def forward_local_port(self, remoteport, localport=None):
"""Forward a local port to a remote port on the target

A context manager that keeps a local port forwarded to a remote port as
long as the context remains valid. A connection can be made to the
returned port on localhost and it will be forwarded to the remote port
on the target device

usage:
with ssh.forward_local_port(8080) as localport:
# Use localhost:localport here to connect to port 8080 on the
# target

returns:
localport
"""
if not self._check_keepalive():
raise ExecutionError("Keepalive no longer running")

if localport is None:
localport = get_free_port()

forward = "-L%d:localhost:%d" % (localport, remoteport)
with self._forward(forward):
yield localport

@Driver.check_active
@contextlib.contextmanager
def forward_remote_port(self, remoteport, localport):
"""Forward a remote port on the target to a local port

A context manager that keeps a remote port forwarded to a local port as
long as the context remains valid. A connection can be made to the
remote on the target device will be forwarded to the returned local
port on localhost

usage:
with ssh.forward_remote_port(8080, 8081) as localport:
# Connections to port 8080 on the target will be redirected to
# localhost:8081
"""
if not self._check_keepalive():
raise ExecutionError("Keepalive no longer running")

forward = "-R%d:localhost:%d" % (remoteport, localport)
with self._forward(forward):
yield

@Driver.check_active
@step(args=['src', 'dst'])
def scp(self, *, src, dst):
Expand Down
71 changes: 70 additions & 1 deletion labgrid/remote/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
coordinator, acquire a place and interact with the connected resources"""
import argparse
import asyncio
import contextlib
import os
import subprocess
import traceback
import logging
import signal
import sys
from textwrap import indent
from socket import gethostname
Expand Down Expand Up @@ -1087,6 +1089,29 @@ def sshfs(self):

res = drv.sshfs(path=self.args.path, mountpoint=self.args.mountpoint)

def forward(self):
if not self.args.local and not self.args.remote:
print("Nothing to forward")
return

drv = self._get_ssh()

with contextlib.ExitStack() as stack:
for local, remote in self.args.local:
localport = stack.enter_context(drv.forward_local_port(remote, localport=local))
print("Forwarding local port %d to remote port %d" % (localport, remote))

for local, remote in self.args.remote:
stack.enter_context(drv.forward_remote_port(remote, localport))
print("Forwarding remote port %d to local port %d" % (remote, local))

try:
print("Waiting for CTRL+C...")
while True:
signal.pause()
except KeyboardInterrupt:
print("Exiting...")

def telnet(self):
place = self.get_acquired_place()
ip = self._get_ip(place)
Expand Down Expand Up @@ -1328,6 +1353,40 @@ def find_any_role_with_place(config):
return (role, place)
return None, None

class LocalPort(argparse.Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
if nargs is not None:
raise ValueError("nargs not allowed")
super().__init__(option_strings, dest, nargs=None, default=[], **kwargs)

def __call__(self, parser, namespace, value, option_string):
if ":" in value:
(local, remote) = value.split(":")
local = int(local)
remote = int(remote)
else:
local = None
remote = int(value)

v = getattr(namespace, self.dest, [])
v.append((local, remote))
setattr(namespace, self.dest, v)

class RemotePort(argparse.Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
if nargs is not None:
raise ValueError("nargs not allowed")
super().__init__(option_strings, dest, nargs=None, default=[], **kwargs)

def __call__(self, parser, namespace, value, option_string):
(remote, local) = value.split(":")
remote = int(remote)
local = int(local)

v = getattr(namespace, self.dest, [])
v.append((local, remote))
setattr(namespace, self.dest, v)

def main():
processwrapper.enable_logging()
logging.basicConfig(
Expand Down Expand Up @@ -1568,6 +1627,16 @@ def main():
subparser.add_argument('mountpoint', help='local path')
subparser.set_defaults(func=ClientSession.sshfs)

subparser = subparsers.add_parser('forward',
help="forward local port to remote target")
subparser.add_argument("--local", "-L", metavar="[LOCAL:]REMOTE",
action=LocalPort,
help="Forward local port LOCAL to remote port REMOTE. If LOCAL is unspecified, an arbitrary port will be chosen")
subparser.add_argument("--remote", "-R", metavar="REMOTE:LOCAL",
action=RemotePort,
help="Forward remote port REMOTE to local port LOCAL")
subparser.set_defaults(func=ClientSession.forward)

subparser = subparsers.add_parser('telnet',
help="connect via telnet")
subparser.set_defaults(func=ClientSession.telnet)
Expand Down Expand Up @@ -1647,7 +1716,7 @@ def main():

# make any leftover arguments available for some commands
args, leftover = parser.parse_known_args()
if args.command not in ['ssh', 'rsync']:
if args.command not in ['ssh', 'rsync', 'forward']:
args = parser.parse_args()
else:
args.leftover = leftover
Expand Down
41 changes: 41 additions & 0 deletions tests/test_sshdriver.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest
import socket

from labgrid.driver import SSHDriver, ExecutionError
from labgrid.exceptions import NoResourceFoundError
from labgrid.resource import NetworkService
from labgrid.util.helper import get_free_port

@pytest.fixture(scope='function')
def ssh_driver_mocked_and_activated(target, mocker):
Expand Down Expand Up @@ -131,3 +133,42 @@ def test_local_run_check(ssh_localhost, tmpdir):

res = ssh_localhost.run_check("echo Hello")
assert res == (["Hello"])

@pytest.mark.sshusername
def test_local_port_forward(ssh_localhost, tmpdir):
remoteport = get_free_port()
test_string = "Hello World"

with ssh_localhost.forward_local_port(remoteport) as localport:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket:
server_socket.bind(("127.0.0.1", remoteport))
server_socket.listen(1)

send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
send_socket.connect(("127.0.0.1", localport))

client_socket, address = server_socket.accept()
send_socket.send(test_string.encode('utf-8'))

assert client_socket.recv(16).decode("utf-8") == test_string

@pytest.mark.sshusername
def test_local_remote_forward(ssh_localhost, tmpdir):
remoteport = get_free_port()
localport = get_free_port()
test_string = "Hello World"

with ssh_localhost.forward_remote_port(remoteport, localport):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_socket:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as send_socket:
server_socket.bind(("127.0.0.1", localport))
server_socket.listen(1)

send_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
send_socket.connect(("127.0.0.1", remoteport))

client_socket, address = server_socket.accept()
send_socket.send(test_string.encode('utf-8'))

assert client_socket.recv(16).decode("utf-8") == test_string