Skip to content

Commit

Permalink
Merge pull request #690 from christian-monch/consistent_shell_command…
Browse files Browse the repository at this point in the history
…_quoting

Consistent shell command quoting
  • Loading branch information
mih committed May 17, 2024
2 parents 37246b1 + 5d5ee93 commit 3691e63
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 32 deletions.
35 changes: 27 additions & 8 deletions datalad_next/shell/operations/posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
PurePosixPath,
)
from queue import Queue
from shlex import quote as posix_quote
from typing import (
BinaryIO,
Callable,
Expand Down Expand Up @@ -42,6 +43,19 @@ def get_final_command(self, remote_file_name: bytes) -> bytes:
This method is usually only called by
:meth:`ShellCommandExecutor.__call__`.
Parameters
----------
remote_file_name : bytes
The name of the file that should be downloaded. If the file name
contains special characters, e.g. a space or ``$``, it must be
quoted for a POSIX shell, for example with ``shlex.quote``.
Returns
-------
bytes
The final command that will be executed in the persistent shell
in order to start the download in the connected shell.
"""
command = b"""
test -r {remote_file_name}
Expand Down Expand Up @@ -79,9 +93,9 @@ def upload(
shell : ShellCommandExecutor
The shell that should be used to upload the file.
local_path : Path
The file that should be uploaded.
The path of the file that should be uploaded.
remote_path : PurePosixPath
The name of the file on the connected shell that will contain the
The path of the file on the connected shell that will contain the
uploaded content.
progress_callback : callable[[int, int], None], optional, default: None
If given, the callback is called with the number of bytes that have
Expand Down Expand Up @@ -144,7 +158,7 @@ def signaling_read(
# `rm -rf $HOME`.
file_size = local_path.stat().st_size
cmd_line = (
f'head -c {file_size} > "{remote_path.as_posix()}" '
f'head -c {file_size} > {posix_quote(str(remote_path))}'
f"|| (head -c {file_size} > /dev/null; test 1 == 2)"
)
with local_path.open("rb") as local_file:
Expand Down Expand Up @@ -195,18 +209,18 @@ def download(
shell: ShellCommandExecutor
The shell from which a file should be downloaded.
remote_path : PurePosixPath
The name of the file on the connected shell that should be
The path of the file on the connected shell that should be
downloaded.
local_path : Path
The name of the local file that will contain the downloaded content.
The path of the local file that will contain the downloaded content.
progress_callback : callable[[int, int], None], optional, default: None
If given, the callback is called with the number of bytes that have
been received and the total number of bytes that should be received.
response_generator_class : type[DownloadResponseGenerator], optional, default: DownloadResponseGeneratorPosix
The response generator that should be used to handle the download
output. It must be a subclass of :class:`DownloadResponseGenerator`.
The default works if the connected shell runs on a Unix-like system that
provides `ls -dln` and `awk`, e.g. ``Linux`` or ``OSX``.
provides `ls -dln`, `cat`, `echo`, and `awk`, e.g. ``Linux`` or ``OSX``.
check : bool, optional, default: False
If ``True``, raise a :class:`CommandError` if the remote operation does
not exit with a ``0`` as return code.
Expand All @@ -225,7 +239,7 @@ def download(
``chunk_size`` keyword argument to :func:`shell`)) bytes of stderr
output.
"""
command = remote_path.as_posix().encode()
command = posix_quote(str(remote_path)).encode()
response_generator = response_generator_class(shell.stdout)
result_generator = shell.start(
command,
Expand Down Expand Up @@ -288,7 +302,12 @@ def delete(
output.
"""
cmd_line = (
"rm " + ("-f " if force else "") + " ".join(f"{f.as_posix()}" for f in files)
"rm "
+ ("-f " if force else "")
+ " ".join(
f"{posix_quote(str(f))}"
for f in files
)
)
result = shell(cmd_line.encode())
if check:
Expand Down
1 change: 0 additions & 1 deletion datalad_next/shell/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,6 @@ def start(self,
self.process_inputs.put(stdin)
return response_generator


def __repr__(self):
    """Return ``<class name>(<repr of self.shell_cmd>)`` for debugging."""
    class_name = self.__class__.__name__
    return f'{class_name}({self.shell_cmd!r})'

Expand Down
64 changes: 41 additions & 23 deletions datalad_next/shell/tests/test_shell.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import os
import sys
from pathlib import PurePosixPath
from json import loads
from shlex import quote as posix_quote

import pytest
from more_itertools import consume
Expand Down Expand Up @@ -34,6 +35,21 @@
# Some files that are usually found on POSIX systems, i.e. Linux, OSX
common_files = [b'/etc/passwd', b'/etc/shells']

# Pick "challenging" file names, i.e. names that need proper shell quoting.
# How challenging they can be depends on where the tests run: a restricted
# file system (TMPDIR below `/crippledfs`, presumably FAT), Windows, or a
# regular POSIX system.
if os.getenv('TMPDIR', '').startswith('/crippledfs'):
    upload_file_name, download_file_name = "up 1", "down 1"
    files_to_delete = ('f 1', 'f 2', 'f 3')
elif on_windows:
    upload_file_name, download_file_name = "upload 123", "download 123"
    files_to_delete = ('f 1', 'f 2', 'f 3')
else:
    # Only here are `$`, quotes, and backslashes used in the file names.
    upload_file_name, download_file_name = (
        "upload $123 \"'",
        "download $123 \" ' ",
    )
    files_to_delete = ('f $1', 'f \\2 " ', 'f 3 \' ')


def _get_cmdline(ssh_url: str):
args, parsed = ssh_url2openargs(ssh_url, datalad.cfg)
Expand All @@ -56,10 +72,15 @@ def test_basic_functionality_multi(sshserver):
_check_ls_result(ssh_executor, file_name)


def _quote_file_name(file_name: bytes, *, encoding: str = 'utf-8') -> bytes:
    """Return ``file_name`` quoted for use in a POSIX shell command line.

    The name is decoded with ``encoding``, quoted via ``shlex.quote``
    (imported as ``posix_quote``), and encoded again with ``encoding``.
    """
    decoded_name = file_name.decode(encoding)
    quoted_name = posix_quote(decoded_name)
    return quoted_name.encode(encoding)


def _check_ls_result(ssh_executor, file_name: bytes):
    """Check that ``ls <file_name>`` run via ``ssh_executor`` echoes the name.

    The file name is quoted for a POSIX shell before it is placed on the
    command line, so names with spaces or other special characters are
    passed to ``ls`` as a single argument. The command is issued once as
    ``bytes`` and once as ``str``.

    Parameters
    ----------
    ssh_executor
        Callable that accepts a command (``bytes`` or ``str``) and returns
        an object with a ``stdout`` attribute holding ``bytes``.
    file_name : bytes
        Unquoted name of the file to list.
    """
    quoted_file_name = _quote_file_name(file_name)

    # Command given as bytes.
    result = ssh_executor(b'ls ' + quoted_file_name)
    assert result.stdout == file_name + b'\n'

    # Command given as str.
    result = ssh_executor('ls ' + quoted_file_name.decode())
    assert result.stdout == file_name + b'\n'


Expand Down Expand Up @@ -143,8 +164,7 @@ def test_upload(sshserver, tmp_path):
ssh_url, local_path = sshserver
ssh_args, ssh_path = _get_cmdline(ssh_url)
content = '0123456789'
test_file_name = 'upload_123'
upload_file = tmp_path / test_file_name
upload_file = tmp_path / upload_file_name
upload_file.write_text(content)
progress = []
with shell(ssh_args) as ssh_executor:
Expand All @@ -155,11 +175,11 @@ def test_upload(sshserver, tmp_path):
result = posix.upload(
ssh_executor,
upload_file,
PurePosixPath(ssh_path + '/' + test_file_name),
progress_callback=lambda a,b: progress.append((a,b))
PurePosixPath(ssh_path + '/' + upload_file_name),
progress_callback=lambda a, b: progress.append((a, b))
)
assert result.returncode == 0
assert (local_path / test_file_name).read_text() == content
assert (local_path / upload_file_name).read_text() == content
assert len(progress) > 0

# perform another operation on the remote shell to ensure functionality
Expand All @@ -170,10 +190,9 @@ def test_download_ssh(sshserver, tmp_path):
ssh_url, local_path = sshserver
ssh_args, ssh_path = _get_cmdline(ssh_url)
content = '0123456789'
test_file_name = 'download_123'
server_file = local_path / test_file_name
server_file = local_path / download_file_name
server_file.write_text(content)
download_file = tmp_path / test_file_name
download_file = tmp_path / download_file_name
progress = []
with shell(ssh_args) as ssh_executor:
# perform an operation on the remote shell
Expand All @@ -182,9 +201,9 @@ def test_download_ssh(sshserver, tmp_path):
# download file from server and verify its content
result = posix.download(
ssh_executor,
PurePosixPath(ssh_path + '/' + test_file_name),
PurePosixPath(ssh_path + '/' + download_file_name),
download_file,
progress_callback=lambda a,b: progress.append((a,b))
progress_callback=lambda a, b: progress.append((a, b))
)
assert result.returncode == 0
assert download_file.read_text() == content
Expand All @@ -199,9 +218,9 @@ def test_download_ssh(sshserver, tmp_path):
@skip_if(on_windows)
def test_download_local_bash(tmp_path):
content = '0123456789'
download_file = tmp_path / 'download_123'
download_file = tmp_path / download_file_name
download_file.write_text(content)
result_file = tmp_path / 'result_123'
result_file = tmp_path / ('result' + download_file_name)
progress = []
with shell(['bash']) as bash:
_check_ls_result(bash, common_files[0])
Expand All @@ -211,7 +230,7 @@ def test_download_local_bash(tmp_path):
bash,
PurePosixPath(download_file),
result_file,
progress_callback=lambda a,b: progress.append((a,b)),
progress_callback=lambda a, b: progress.append((a, b)),
)
assert result_file.read_text() == content
assert len(progress) > 0
Expand All @@ -224,9 +243,9 @@ def test_download_local_bash(tmp_path):
@skip_if(on_windows)
def test_upload_local_bash(tmp_path):
content = '0123456789'
upload_file = tmp_path / 'upload_123'
upload_file = tmp_path / upload_file_name
upload_file.write_text(content)
result_file = tmp_path / 'result_123'
result_file = tmp_path / ('result' + upload_file_name)
progress = []
with shell(['bash']) as bash:
_check_ls_result(bash, common_files[0])
Expand All @@ -236,7 +255,7 @@ def test_upload_local_bash(tmp_path):
bash,
upload_file,
PurePosixPath(result_file),
progress_callback=lambda a,b: progress.append((a,b)),
progress_callback=lambda a, b: progress.append((a, b)),
)
assert result_file.read_text() == content
assert len(progress) > 0
Expand All @@ -261,7 +280,7 @@ def test_upload_local_bash_error(tmp_path):
bash,
source_file,
destination_file,
progress_callback=lambda a,b: progress.append((a,b)),
progress_callback=lambda a, b: progress.append((a, b)),
)
assert result.returncode != 0
assert len(progress) > 0
Expand All @@ -280,7 +299,6 @@ def test_delete(sshserver):
ssh_url, local_path = sshserver
ssh_args, ssh_path = _get_cmdline(ssh_url)

files_to_delete = ('f1', 'f2', 'f3')
with shell(ssh_args) as ssh_executor:
for file in files_to_delete:
(local_path / file).write_text(f'content_{file}')
Expand Down Expand Up @@ -508,7 +526,7 @@ def test_download_error(tmp_path):
bash,
PurePosixPath('/thisdoesnotexist'),
tmp_path / 'downloaded_file',
progress_callback=lambda a,b: progress.append((a,b)),
progress_callback=lambda a, b: progress.append((a, b)),
check=True,
)
_check_ls_result(bash, common_files[0])
Expand All @@ -517,7 +535,7 @@ def test_download_error(tmp_path):
bash,
PurePosixPath('/thisdoesnotexist'),
tmp_path / 'downloaded_file',
progress_callback=lambda a,b: progress.append((a,b)),
progress_callback=lambda a, b: progress.append((a, b)),
check=False,
)
assert result.returncode not in (0, None)
Expand Down

0 comments on commit 3691e63

Please sign in to comment.