Skip to content

Commit

Permalink
Merge pull request #1305 from nlabriet/sshdriver-credentials
Browse files Browse the repository at this point in the history
driver/sshdriver: Add credential overriding NetworkService's
  • Loading branch information
Bastian-Krause committed Dec 15, 2023
2 parents 3b1df77 + 0d8b82b commit 6f3a32b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 9 deletions.
2 changes: 2 additions & 0 deletions doc/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,8 @@ Arguments:
explicitly use the SFTP protocol for file transfers instead of scp's default protocol
- explicit_scp_mode (bool, default=False): if set to True, `put()`, `get()`, and `scp()` will
explicitly use the SCP protocol for file transfers instead of scp's default protocol
- username (str, default=username from `NetworkService`): username used by SSH
- password (str, default=password from `NetworkService`): password used by SSH

UBootDriver
~~~~~~~~~~~
Expand Down
28 changes: 19 additions & 9 deletions labgrid/driver/sshdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,29 @@ class SSHDriver(CommandMixin, Driver, CommandProtocol, FileTransferProtocol):
connection_timeout = attr.ib(default=float(get_ssh_connect_timeout()), validator=attr.validators.instance_of(float))
explicit_sftp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool))
explicit_scp_mode = attr.ib(default=False, validator=attr.validators.instance_of(bool))
username = attr.ib(default="", validator=attr.validators.instance_of(str))
password = attr.ib(default="", validator=attr.validators.instance_of(str))

def __attrs_post_init__(self):
super().__attrs_post_init__()
self._keepalive = None

def _get_username(self):
"""Get the username from this class or from NetworkService"""
return self.username or self.networkservice.username

def _get_password(self):
"""Get the password from this class or from NetworkService"""
return self.password or self.networkservice.password

def on_activate(self):
self.ssh_prefix = ["-o", "LogLevel=ERROR"]
if self.keyfile:
keyfile_path = self.keyfile
if self.target.env:
keyfile_path = self.target.env.config.resolve_path(self.keyfile)
self.ssh_prefix += ["-i", keyfile_path ]
if not self.networkservice.password:
if not self._get_password():
self.ssh_prefix += ["-o", "PasswordAuthentication=no"]

self.control = self._start_own_master()
Expand Down Expand Up @@ -99,7 +109,7 @@ def _start_own_master_once(self, timeout):
"-o", "ControlPersist=300", "-o",
"UserKnownHostsFile=/dev/null", "-o", "StrictHostKeyChecking=no",
"-o", "ServerAliveInterval=15", "-MN", "-S", control.replace('%', '%%'), "-p",
str(self.networkservice.port), "-l", self.networkservice.username,
str(self.networkservice.port), "-l", self._get_username(),
self.networkservice.address]

# proxy via the exporter if we have an ifname suffix
Expand All @@ -119,14 +129,14 @@ def _start_own_master_once(self, timeout):

env = os.environ.copy()
pass_file = ''
if self.networkservice.password:
if self._get_password():
fd, pass_file = tempfile.mkstemp()
os.fchmod(fd, stat.S_IRWXU)
#with openssh>=8.4 SSH_ASKPASS_REQUIRE can be used to force SSH_ASK_PASS
#openssh<8.4 requires the DISPLAY var and a detached process with start_new_session=True
env = {'SSH_ASKPASS': pass_file, 'DISPLAY':'', 'SSH_ASKPASS_REQUIRE':'force'}
with open(fd, 'w') as f:
f.write("#!/bin/sh\necho " + shlex.quote(self.networkservice.password))
f.write("#!/bin/sh\necho " + shlex.quote(self._get_password()))

self.process = subprocess.Popen(args, env=env,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -163,7 +173,7 @@ def _start_own_master_once(self, timeout):
f"Subprocess timed out [{subprocess_timeout}s] while executing {args}",
)
finally:
if self.networkservice.password and os.path.exists(pass_file):
if self._get_password() and os.path.exists(pass_file):
os.remove(pass_file)

if not os.path.exists(control):
Expand Down Expand Up @@ -194,7 +204,7 @@ def _run(self, cmd, codec="utf-8", decodeerrors="strict", timeout=None):
raise ExecutionError("Keepalive no longer running")

complete_cmd = ["ssh", "-x", *self.ssh_prefix,
"-p", str(self.networkservice.port), "-l", self.networkservice.username,
"-p", str(self.networkservice.port), "-l", self._get_username(),
self.networkservice.address
] + cmd.split(" ")
self.logger.debug("Sending command: %s", complete_cmd)
Expand Down Expand Up @@ -467,7 +477,7 @@ def put(self, filename, remotepath=''):
"-P", str(self.networkservice.port),
"-r",
filename,
f"{self.networkservice.username}@{self.networkservice.address}:{remotepath}"
f"{self._get_username()}@{self.networkservice.address}:{remotepath}"
]

if self.explicit_sftp_mode and self._scp_supports_explicit_sftp_mode():
Expand Down Expand Up @@ -496,7 +506,7 @@ def get(self, filename, destination="."):
*self.ssh_prefix,
"-P", str(self.networkservice.port),
"-r",
f"{self.networkservice.username}@{self.networkservice.address}:{filename}",
f"{self._get_username()}@{self.networkservice.address}:{filename}",
destination
]

Expand All @@ -520,7 +530,7 @@ def get(self, filename, destination="."):

def _cleanup_own_master(self):
"""Exit the controlmaster and delete the tmpdir"""
complete_cmd = f"ssh -x -o ControlPath={self.control.replace('%', '%%')} -O exit -p {self.networkservice.port} -l {self.networkservice.username} {self.networkservice.address}".split(' ') # pylint: disable=line-too-long
complete_cmd = f"ssh -x -o ControlPath={self.control.replace('%', '%%')} -O exit -p {self.networkservice.port} -l {self._get_username()} {self.networkservice.address}".split(' ') # pylint: disable=line-too-long
res = subprocess.call(
complete_cmd,
stdin=subprocess.DEVNULL,
Expand Down

0 comments on commit 6f3a32b

Please sign in to comment.