diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..6a9490e --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,14 @@ +version: 2 + +build: + os: ubuntu-22.04 + tools: + python: "3.11" + +python: + install: + - requirements: docs/requirements.txt + +sphinx: + configuration: docs/conf.py + diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 28ed69a..0000000 --- a/.travis.yml +++ /dev/null @@ -1,86 +0,0 @@ -language: - - python - -install: - - .travis/install.sh - -script: - - export PATH="/usr/local/opt/openssl/bin:$PATH" - - export PATH="$HOME/.pyenv/shims:$HOME/.pyenv/bin:$PATH" - - travis_wait 60 tox - -matrix: - fast_finish: true - include: - - os: linux - sudo: required - dist: xenial - python: 3.6 - env: TOXENV=py36 - - os: linux - sudo: required - dist: xenial - python: 3.6 - env: TOXENV=py36 USE_UVLOOP=1 - - os: linux - sudo: required - dist: xenial - python: 3.7 - env: TOXENV=py37 - - os: linux - sudo: required - dist: xenial - python: 3.7 - env: TOXENV=py37 USE_UVLOOP=1 - - os: linux - sudo: required - dist: xenial - python: 3.8 - env: TOXENV=py38 - - os: linux - sudo: required - dist: xenial - python: 3.8 - env: TOXENV=py38 USE_UVLOOP=1 - - os: linux - sudo: required - dist: xenial - python: 3.9-dev - env: TOXENV=py39 - - os: linux - sudo: required - dist: xenial - python: 3.9-dev - env: TOXENV=py39 USE_UVLOOP=1 - - os: osx - osx_image: xcode9.4 - language: generic - env: PYENV_VERSION=3.6-dev TOXENV=py36 - - os: osx - osx_image: xcode9.4 - language: generic - env: PYENV_VERSION=3.6-dev TOXENV=py36 USE_UVLOOP=1 - - os: osx - osx_image: xcode9.4 - language: generic - env: PYENV_VERSION=3.7-dev TOXENV=py37 - - os: osx - osx_image: xcode9.4 - language: generic - env: PYENV_VERSION=3.7-dev TOXENV=py37 USE_UVLOOP=1 - - os: osx - osx_image: xcode11.2 - language: generic - env: PYENV_VERSION=3.8-dev TOXENV=py38 - - os: osx - osx_image: xcode11.2 - language: generic - env: PYENV_VERSION=3.8-dev TOXENV=py38 USE_UVLOOP=1 - - os: osx - osx_image: xcode12.2 - language: generic - env: PYENV_VERSION=3.9-dev TOXENV=py39 - - os: osx - osx_image: xcode12.2 - language: generic - env: PYENV_VERSION=3.9-dev TOXENV=py39 USE_UVLOOP=1 diff --git a/.travis/install.sh b/.travis/install.sh deleted file mode 100755 index ea7d72a..0000000 --- a/.travis/install.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -if [[ $TRAVIS_OS_NAME == 'osx' ]]; then - export CPPFLAGS="-I/usr/local/opt/openssl/include" - export LDFLAGS="-L/usr/local/opt/openssl/lib -L/usr/local/opt/libffi/lib" - export PATH="$HOME/.pyenv/bin:/usr/local/opt/openssl/bin:$PATH" - - brew update - brew install libffi libsodium - eval "$(pyenv init -)" - pyenv install $PYENV_VERSION - pyenv global $PYENV_VERSION - pyenv rehash -else - git clone git://github.com/jedisct1/libsodium.git - cd libsodium - ./autogen.sh - ./configure - make && sudo make install - sudo ldconfig -fi - -pip install tox diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 5697b67..c6e7862 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -68,19 +68,18 @@ contributors list. Branches -------- -There are two long-lived branches in AsyncSSH at the moment: +There are two long-lived branches in AsyncSSH: * The master branch is intended to contain the latest stable version of the code. All official versions of AsyncSSH are released from this branch, and each release has a corresponding tag added - matching its release number. Bug fixes and simple improvements - may be checked directly into this branch, but most new features - will be added to the develop branch first. - -* The develop branch is intended to contain features for developers - to test before they are ready to be added to an official release. - APIs in the develop branch may be subject to change until they - are migrated back to master, and there's no guarantee of backward + matching its release number. + +* The develop branch is intended to contain new features and bug fixes + ready to be tested before being added to an official release. APIs + in the develop branch may be subject to change until they are + migrated back to master, and there's no guarantee of backward compatibility in this branch. However, pulling from this branch will provide early access to new functionality and a chance to - influence this functionality before it is released. + influence this functionality before it is released. Also, all + pull requests should be submitted against this branch. diff --git a/README.rst b/README.rst index 959e70b..5203feb 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,12 @@ +.. image:: https://readthedocs.org/projects/asyncssh/badge/?version=latest + :target: https://asyncssh.readthedocs.io/en/latest/?badge=latest + :alt: Documentation Status + +.. image:: https://img.shields.io/pypi/v/asyncssh.svg + :target: https://pypi.python.org/pypi/asyncssh/ + :alt: AsyncSSH PyPI Project + + AsyncSSH: Asynchronous SSH for Python ===================================== @@ -32,6 +41,7 @@ Features * Environment variables, terminal type, and window size * Direct and forwarded TCP/IP channels * OpenSSH-compatible direct and forwarded UNIX domain socket channels + * OpenSSH-compatible TUN/TAP channels and packet forwarding * Local and remote TCP/IP port forwarding * Local and remote UNIX domain socket forwarding * Dynamic TCP/IP port forwarding via SOCKS @@ -46,6 +56,9 @@ Features * Multiple SSH connections in a single event loop * Byte and string based I/O with settable encoding * A variety of `key exchange`__, `encryption`__, and `MAC`__ algorithms + + * Including post-quantum kex algorithms ML-KEM and SNTRUP + * Support for `gzip compression`__ * Including OpenSSH variant to delay compression until after auth @@ -88,7 +101,7 @@ License This package is released under the following terms: - Copyright (c) 2013-2022 by Ron Frederick and others. + Copyright (c) 2013-2024 by Ron Frederick and others. This program and the accompanying materials are made available under the terms of the Eclipse Public License v2.0 which accompanies this @@ -114,7 +127,7 @@ Prerequisites To use AsyncSSH 2.0 or later, you need the following: * Python 3.6 or later -* cryptography (PyCA) 2.8 or later +* cryptography (PyCA) 3.1 or later Installation ------------ @@ -143,6 +156,10 @@ functionality: * Install gssapi from https://pypi.python.org/pypi/gssapi if you want support for GSSAPI key exchange and authentication on UNIX. +* Install liboqs from https://github.com/open-quantum-safe/liboqs + if you want support for the OpenSSH post-quantum key exchange + algorithms based on ML-KEM and SNTRUP. + * Install libsodium from https://github.com/jedisct1/libsodium and libnacl from https://pypi.python.org/pypi/libnacl if you have a version of OpenSSL older than 1.1.1b installed and you want @@ -186,8 +203,8 @@ Windows, you can run: Note that you will still need to manually install the libsodium library listed above for libnacl to work correctly and/or libnettle for UMAC -support. Unfortunately, since libsodium and libnettle are not Python -packages, they cannot be directly installed using pip. +support. Unfortunately, since liboqs, libsodium, and libnettle are not +Python packages, they cannot be directly installed using pip. Installing the development branch ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..ca92b4b --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,22 @@ +# Secuity Policy + +## Supported Versions + +AsyncSSH has only one active development branch at this time. Any bug or +vulnerability fixes will be fixed in the "develop" branch first and then +migrated to the "master" branch in preparation for putting out a new release. + +## Reporting Vulnerabilities + +**⚠️ Please do not file GitHub issues for security vulnerabilities as they are +public! ⚠️** + +If you believe you have found a security vulnerability in AsyncSSH, please +create a draft +[security advisory](https://github.com/ronf/asyncssh/security/advisories/new) +or send an e-mail to security@asyncssh.com with a description of the issue +and details for how to reproduce it. This report will be reviewed and you'll +be contacted if further information is required, or when a fix is available. + +Published security advisories for AsyncSSH can be found +[here](https://github.com/ronf/asyncssh/security/advisories). diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 02bb5af..0000000 --- a/appveyor.yml +++ /dev/null @@ -1,21 +0,0 @@ -environment: - matrix: - - TOXENV: py36 - -image: - - Visual Studio 2017 - -platform: - - x86 - - x64 - -install: - - "curl https://www.timeheart.net/appveyor/%PLATFORM%/libsodium-18.dll -O" - - "curl https://www.timeheart.net/appveyor/%PLATFORM%/libnettle-6.dll -O" - - "curl https://www.timeheart.net/appveyor/%PLATFORM%/libhogweed-4.dll -O" - - pip install tox - -build: off - -test_script: - - tox diff --git a/asyncssh/__init__.py b/asyncssh/__init__.py index 9b388c3..fb316e0 100644 --- a/asyncssh/__init__.py +++ b/asyncssh/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,7 +34,7 @@ from .auth_keys import import_authorized_keys, read_authorized_keys from .channel import SSHClientChannel, SSHServerChannel -from .channel import SSHTCPChannel, SSHUNIXChannel +from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel from .client import SSHClient @@ -44,9 +44,10 @@ from .connection import SSHAcceptor, SSHClientConnection, SSHServerConnection from .connection import SSHClientConnectionOptions, SSHServerConnectionOptions +from .connection import SSHAcceptHandler from .connection import create_connection, create_server, connect, listen from .connection import connect_reverse, listen_reverse, get_server_host_key -from .connection import get_server_auth_methods +from .connection import get_server_auth_methods, run_client, run_server from .editor import SSHLineEditorChannel @@ -86,10 +87,12 @@ from .public_key import load_keypairs, load_public_keys, load_certificates from .public_key import load_resident_keys +from .rsa import set_default_skip_rsa_key_validation + from .scp import scp from .session import DataType, SSHClientSession, SSHServerSession -from .session import SSHTCPSession, SSHUNIXSession +from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .server import SSHServer @@ -106,7 +109,7 @@ from .sftp import SFTPFileCorrupt, SFTPOwnerInvalid, SFTPGroupInvalid from .sftp import SFTPNoMatchingByteRangeLock from .sftp import SFTPConnectionLost, SFTPOpUnsupported -from .sftp import SFTPAttrs, SFTPVFSAttrs, SFTPName +from .sftp import SFTPAttrs, SFTPVFSAttrs, SFTPName, SFTPLimits from .sftp import SEEK_SET, SEEK_CUR, SEEK_END from .stream import SSHSocketSessionFactory, SSHServerSessionFactory @@ -117,3 +120,54 @@ # Import these explicitly to trigger register calls in them from . import sk_eddsa, sk_ecdsa, eddsa, ecdsa, rsa, dsa, kex_dh, kex_rsa + +__all__ = [ + '__author__', '__author_email__', '__url__', '__version__', + 'BreakReceived', 'BytesOrStr', 'ChannelListenError', + 'ChannelOpenError', 'CompressionError', 'ConfigParseError', + 'ConnectionLost', 'DEVNULL', 'DataType', 'DisconnectError', 'Error', + 'HostKeyNotVerifiable', 'IllegalUserName', 'KeyEncryptionError', + 'KeyExchangeFailed', 'KeyExportError', 'KeyGenerationError', + 'KeyImportError', 'MACError', 'PIPE', 'PasswordChangeRequired', + 'PermissionDenied', 'ProcessError', 'ProtocolError', + 'ProtocolNotSupported', 'SEEK_CUR', 'SEEK_END', 'SEEK_SET', + 'SFTPAttrs', 'SFTPBadMessage', 'SFTPByteRangeLockConflict', + 'SFTPByteRangeLockRefused', 'SFTPCannotDelete', 'SFTPClient', + 'SFTPClientFile', 'SFTPConnectionLost', 'SFTPDeletePending', + 'SFTPDirNotEmpty', 'SFTPEOFError', 'SFTPError', 'SFTPFailure', + 'SFTPFileAlreadyExists', 'SFTPFileCorrupt', 'SFTPFileIsADirectory', + 'SFTPGroupInvalid', 'SFTPInvalidFilename', 'SFTPInvalidHandle', + 'SFTPInvalidParameter', 'SFTPLimits', 'SFTPLinkLoop', 'SFTPLockConflict', + 'SFTPName', 'SFTPNoConnection', 'SFTPNoMatchingByteRangeLock', + 'SFTPNoMedia', 'SFTPNoSpaceOnFilesystem', 'SFTPNoSuchFile', + 'SFTPNoSuchPath', 'SFTPNotADirectory', 'SFTPOpUnsupported', + 'SFTPOwnerInvalid', 'SFTPPermissionDenied', 'SFTPQuotaExceeded', + 'SFTPServer', 'SFTPServerFactory', 'SFTPUnknownPrincipal', 'SFTPVFSAttrs', + 'SFTPWriteProtect', 'SSHAcceptHandler', 'SSHAcceptor', 'SSHAgentClient', + 'SSHAgentKeyPair', 'SSHAuthorizedKeys', 'SSHCertificate', 'SSHClient', + 'SSHClientChannel', 'SSHClientConnection', 'SSHClientConnectionOptions', + 'SSHClientProcess', 'SSHClientSession', 'SSHCompletedProcess', + 'SSHForwarder', 'SSHKey', 'SSHKeyPair', 'SSHKnownHosts', + 'SSHLineEditorChannel', 'SSHListener', 'SSHReader', 'SSHServer', + 'SSHServerChannel', 'SSHServerConnection', + 'SSHServerConnectionOptions', 'SSHServerProcess', + 'SSHServerProcessFactory', 'SSHServerSession', + 'SSHServerSessionFactory', 'SSHSocketSessionFactory', + 'SSHSubprocessProtocol', 'SSHSubprocessReadPipe', + 'SSHSubprocessTransport', 'SSHSubprocessWritePipe', 'SSHTCPChannel', + 'SSHTCPSession', 'SSHTunTapChannel', 'SSHTunTapSession', + 'SSHUNIXChannel', 'SSHUNIXSession', 'SSHWriter', + 'STDOUT', 'ServiceNotAvailable', 'SignalReceived', 'TerminalSizeChanged', + 'TimeoutError', 'connect', 'connect_agent', 'connect_reverse', + 'create_connection', 'create_server', 'generate_private_key', + 'get_server_auth_methods', 'get_server_host_key', + 'import_authorized_keys', 'import_certificate', 'import_known_hosts', + 'import_private_key', 'import_public_key', 'listen', 'listen_reverse', + 'load_certificates', 'load_keypairs', 'load_pkcs11_keys', + 'load_public_keys', 'load_resident_keys', 'logger', 'match_known_hosts', + 'read_authorized_keys', 'read_certificate', 'read_certificate_list', + 'read_known_hosts', 'read_private_key', 'read_private_key_list', + 'read_public_key', 'read_public_key_list', 'run_client', 'run_server', + 'scp', 'set_debug_level', 'set_default_skip_rsa_key_validation', + 'set_log_level', 'set_sftp_log_level' +] diff --git a/asyncssh/agent.py b/asyncssh/agent.py index f2831a2..21a46dc 100644 --- a/asyncssh/agent.py +++ b/asyncssh/agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -21,12 +21,11 @@ """SSH agent client""" import asyncio -import errno import os import sys from types import TracebackType from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, Type, Union -from typing_extensions import Protocol +from typing_extensions import Protocol, Self from .listener import SSHForwardListener from .misc import async_context_manager, maybe_wait_closed @@ -58,17 +57,10 @@ async def wait_closed(self) -> None: """Wait for the connection to the SSH agent to close""" -try: - if sys.platform == 'win32': # pragma: no cover - from .agent_win32 import open_agent - else: - from .agent_unix import open_agent -except ImportError as _exc: # pragma: no cover - async def open_agent(agent_path: Optional[str]) -> \ - Tuple[AgentReader, AgentWriter]: - """Dummy function if we're unable to import agent support""" - - raise OSError(errno.ENOENT, 'Agent support unavailable: %s' % str(_exc)) +if sys.platform == 'win32': # pragma: no cover + from .agent_win32 import open_agent +else: + from .agent_unix import open_agent class _SupportsOpenAgentConnection(Protocol): @@ -78,7 +70,7 @@ async def open_agent_connection(self) -> Tuple[AgentReader, AgentWriter]: """Open a forwarded ssh-agent connection back to the client""" -_AgentPath = Union[None, str, _SupportsOpenAgentConnection] +_AgentPath = Union[str, _SupportsOpenAgentConnection] # Client request message numbers @@ -149,6 +141,18 @@ def __init__(self, agent: 'SSHAgentClient', algorithm: bytes, self._is_cert = is_cert self._flags = 0 + @property + def has_cert(self) -> bool: + """ Return if this key pair has an associated cert""" + + return self._is_cert + + @property + def has_x509_chain(self) -> bool: + """ Return if this key pair has an associated X.509 cert chain""" + + return False + def set_certificate(self, cert: SSHCertificate) -> None: """Set certificate to use with this key""" @@ -161,7 +165,7 @@ def set_sig_algorithm(self, sig_algorithm: bytes) -> None: super().set_sig_algorithm(sig_algorithm) - if sig_algorithm == b'rsa-sha2-256': + if sig_algorithm in (b'rsa-sha2-256', b'x509v3-rsa2048-sha256'): self._flags |= SSH_AGENT_RSA_SHA2_256 elif sig_algorithm == b'rsa-sha2-512': self._flags |= SSH_AGENT_RSA_SHA2_512 @@ -186,7 +190,7 @@ def __init__(self, agent_path: _AgentPath): self._writer: Optional[AgentWriter] = None self._lock = asyncio.Lock() - async def __aenter__(self) -> 'SSHAgentClient': + async def __aenter__(self) -> Self: """Allow SSHAgentClient to be used as an async context manager""" return self @@ -222,7 +226,7 @@ def encode_constraints(lifetime: Optional[int], confirm: bool) -> bytes: async def connect(self) -> None: """Connect to the SSH agent""" - if isinstance(self._agent_path, str) or self._agent_path is None: + if isinstance(self._agent_path, str): self._reader, self._writer = await open_agent(self._agent_path) else: self._reader, self._writer = \ @@ -248,7 +252,7 @@ async def _make_request(self, msgtype: int, *args: bytes) -> \ resplen = int.from_bytes((await reader.readexactly(4)), 'big') - resp = SSHPacket((await reader.readexactly(resplen))) + resp = SSHPacket(await reader.readexactly(resplen)) resptype = resp.get_byte() return resptype, resp @@ -294,7 +298,7 @@ async def get_keys(self, identities: Optional[Sequence[bytes]] = None) -> \ resp.check_end() return result else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def sign(self, key_blob: bytes, data: bytes, flags: int = 0) -> bytes: @@ -311,10 +315,11 @@ async def sign(self, key_blob: bytes, data: bytes, elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to sign with requested key') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def add_keys(self, keylist: KeyPairListArg = (), - passphrase: str = None, lifetime: int = None, + passphrase: Optional[str] = None, + lifetime: Optional[int] = None, confirm: bool = False) -> None: """Add keys to the agent @@ -392,10 +397,11 @@ async def add_keys(self, keylist: KeyPairListArg = (), if not ignore_failures: raise ValueError('Unable to add key') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') - async def add_smartcard_keys(self, provider: str, pin: str = None, - lifetime: int = None, + async def add_smartcard_keys(self, provider: str, + pin: Optional[str] = None, + lifetime: Optional[int] = None, confirm: bool = False) -> None: """Store keys associated with a smart card in the agent @@ -432,7 +438,7 @@ async def add_smartcard_keys(self, provider: str, pin: str = None, elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to add keys') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_keys(self, keylist: Sequence[SSHKeyPair]) -> None: """Remove a key stored in the agent @@ -455,10 +461,10 @@ async def remove_keys(self, keylist: Sequence[SSHKeyPair]) -> None: elif resptype == SSH_AGENT_FAILURE: raise ValueError('Key not found') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_smartcard_keys(self, provider: str, - pin: str = None) -> None: + pin: Optional[str] = None) -> None: """Remove keys associated with a smart card stored in the agent :param provider: @@ -481,7 +487,7 @@ async def remove_smartcard_keys(self, provider: str, elif resptype == SSH_AGENT_FAILURE: raise ValueError('Keys not found') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def remove_all(self) -> None: """Remove all keys stored in the agent @@ -498,7 +504,7 @@ async def remove_all(self) -> None: elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to remove all keys') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def lock(self, passphrase: str) -> None: """Lock the agent using the specified passphrase @@ -522,7 +528,7 @@ async def lock(self, passphrase: str) -> None: elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to lock SSH agent') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def unlock(self, passphrase: str) -> None: """Unlock the agent using the specified passphrase @@ -546,7 +552,7 @@ async def unlock(self, passphrase: str) -> None: elif resptype == SSH_AGENT_FAILURE: raise ValueError('Unable to unlock SSH agent') else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') async def query_extensions(self) -> Sequence[str]: """Return a list of extensions supported by the agent @@ -575,7 +581,7 @@ async def query_extensions(self) -> Sequence[str]: elif resptype == SSH_AGENT_FAILURE: return [] else: - raise ValueError('Unknown SSH agent response: %d' % resptype) + raise ValueError(f'Unknown SSH agent response: {resptype}') def close(self) -> None: """Close the SSH agent connection @@ -626,7 +632,7 @@ def close(self) -> None: @async_context_manager -async def connect_agent(agent_path: _AgentPath = None) -> 'SSHAgentClient': +async def connect_agent(agent_path: _AgentPath = '') -> 'SSHAgentClient': """Make a connection to the SSH agent This function attempts to connect to an ssh-agent process @@ -654,7 +660,7 @@ async def connect_agent(agent_path: _AgentPath = None) -> 'SSHAgentClient': """ if not agent_path: - agent_path = os.environ.get('SSH_AUTH_SOCK', None) + agent_path = os.environ.get('SSH_AUTH_SOCK', '') agent = SSHAgentClient(agent_path) await agent.connect() diff --git a/asyncssh/agent_unix.py b/asyncssh/agent_unix.py index 1c0c4ce..63d9a6b 100644 --- a/asyncssh/agent_unix.py +++ b/asyncssh/agent_unix.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -22,7 +22,7 @@ import asyncio import errno -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: @@ -30,8 +30,7 @@ from .agent import AgentReader, AgentWriter -async def open_agent(agent_path: Optional[str]) -> \ - Tuple['AgentReader', 'AgentWriter']: +async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']: """Open a connection to ssh-agent""" if not agent_path: diff --git a/asyncssh/agent_win32.py b/asyncssh/agent_win32.py index 608ddb3..f162f58 100644 --- a/asyncssh/agent_win32.py +++ b/asyncssh/agent_win32.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -27,7 +27,7 @@ import ctypes import ctypes.wintypes import errno -from typing import TYPE_CHECKING, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Tuple, Union, cast from .misc import open_file @@ -78,7 +78,7 @@ class _PageantTransport: """Transport to connect to Pageant agent on Windows""" def __init__(self) -> None: - self._mapname = '%s%08x' % (_AGENT_NAME, win32api.GetCurrentThreadId()) + self._mapname = f'{_AGENT_NAME}{win32api.GetCurrentThreadId():08x}' try: self._mapfile = mmapfile.mmapfile('', self._mapname, @@ -164,8 +164,7 @@ async def wait_closed(self) -> None: """Wait for the transport to close""" -async def open_agent(agent_path: Optional[str]) -> \ - Tuple['AgentReader', 'AgentWriter']: +async def open_agent(agent_path: str) -> Tuple['AgentReader', 'AgentWriter']: """Open a connection to the Pageant or Windows 10 OpenSSH agent""" transport: Union[None, _PageantTransport, _W10OpenSSHTransport] = None @@ -178,7 +177,6 @@ async def open_agent(agent_path: Optional[str]) -> \ agent_path = _DEFAULT_OPENSSH_PATH if not transport: - assert agent_path is not None transport = _W10OpenSSHTransport(agent_path) return transport, transport diff --git a/asyncssh/asn1.py b/asyncssh/asn1.py index fc7ec1f..6a7293b 100644 --- a/asyncssh/asn1.py +++ b/asyncssh/asn1.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -169,8 +169,8 @@ def __init__(self, tag: int, content: bytes, asn1_class: int): self.content = content def __repr__(self) -> str: - return ('RawDERObject(%s, %s, %r)' % - (_asn1_class[self.asn1_class], self.tag, self.content)) + return f'RawDERObject({_asn1_class[self.asn1_class]}, ' \ + f'{self.tag}, {self.content!r})' def __eq__(self, other: object) -> bool: if not isinstance(other, RawDERObject): # pragma: no cover @@ -213,10 +213,10 @@ def __init__(self, tag: int, value: object, def __repr__(self) -> str: if self.asn1_class == CONTEXT_SPECIFIC: - return 'TaggedDERObject(%s, %r)' % (self.tag, self.value) + return f'TaggedDERObject({self.tag}, {self.value!r})' else: - return ('TaggedDERObject(%s, %s, %r)' % - (_asn1_class[self.asn1_class], self.tag, self.value)) + return f'TaggedDERObject({_asn1_class[self.asn1_class]}, ' \ + f'{self.tag}, {self.value!r})' def __eq__(self, other: object) -> bool: if not isinstance(other, TaggedDERObject): # pragma: no cover @@ -469,7 +469,7 @@ def __str__(self) -> str: return result def __repr__(self) -> str: - return "BitString('%s')" % self + return f"BitString('{self}')" def __eq__(self, other: object) -> bool: if not isinstance(other, BitString): # pragma: no cover @@ -508,10 +508,10 @@ def __init__(self, value: Union[bytes, bytearray]): self.value = value def __str__(self) -> str: - return '%s' % self.value.decode('ascii') + return self.value.decode('ascii') def __repr__(self) -> str: - return 'IA5String(%r)' % self.value + return f'IA5String({self.value!r})' def __eq__(self, other: object) -> bool: # pragma: no cover if not isinstance(other, IA5String): @@ -569,7 +569,7 @@ def __str__(self) -> str: return self.value def __repr__(self) -> str: - return "ObjectIdentifier('%s')" % self.value + return f"ObjectIdentifier('{self.value}')" def __eq__(self, other: object) -> bool: if not isinstance(other, ObjectIdentifier): # pragma: no cover @@ -685,7 +685,7 @@ def der_encode(value: object) -> bytes: identifier = cls.identifier content = cls.encode(value) else: - raise ASN1EncodeError('Cannot DER encode type %s' % t.__name__) + raise ASN1EncodeError(f'Cannot DER encode type {t.__name__}') length = len(content) if length < 0x80: diff --git a/asyncssh/auth.py b/asyncssh/auth.py index 40ed3aa..8cfd041 100644 --- a/asyncssh/auth.py +++ b/asyncssh/auth.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -27,6 +27,7 @@ from .gss import GSSBase, GSSError from .logging import SSHLogger from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names +from .misc import run_in_executor from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler from .public_key import SigningKey from .saslprep import saslprep, SASLPrepError @@ -158,7 +159,7 @@ async def _start(self) -> None: await self.send_request(key=self._conn.get_gss_context(), trivial=False) else: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) class _ClientGSSMICAuth(ClientAuth): @@ -180,10 +181,10 @@ async def _start(self) -> None: self._gss = self._conn.get_gss_context() self._gss.reset() - mechs = b''.join((String(mech) for mech in self._gss.mechs)) + mechs = b''.join(String(mech) for mech in self._gss.mechs) await self.send_request(UInt32(len(self._gss.mechs)), mechs) else: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) def _finish(self) -> None: """Finish client GSS MIC authentication""" @@ -199,8 +200,8 @@ def _finish(self) -> None: else: self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE) - def _process_response(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_response(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS response from the server""" mech = packet.get_string() @@ -212,7 +213,7 @@ def _process_response(self, _pkttype: int, _pktid: int, raise ProtocolError('Mechanism mismatch') try: - token = self._gss.step() + token = await run_in_executor(self._gss.step) assert token is not None self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -223,10 +224,10 @@ def _process_response(self, _pkttype: int, _pktid: int, if exc.token: self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token)) - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) - def _process_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS token from the server""" token: Optional[bytes] = packet.get_string() @@ -235,7 +236,7 @@ def _process_token(self, _pkttype: int, _pktid: int, assert self._gss is not None try: - token = self._gss.step(token) + token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -246,7 +247,7 @@ def _process_token(self, _pkttype: int, _pktid: int, if exc.token: self.send_packet(MSG_USERAUTH_GSSAPI_ERRTOK, String(exc.token)) - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) def _process_error(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -261,8 +262,8 @@ def _process_error(self, _pkttype: int, _pktid: int, self.logger.debug1('GSS error from server: %s', msg) self._got_error = True - def _process_error_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_error_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS error token from the server""" token = packet.get_string() @@ -271,7 +272,7 @@ def _process_error_token(self, _pkttype: int, _pktid: int, assert self._gss is not None try: - self._gss.step(token) + await run_in_executor(self._gss.step, token) except GSSError as exc: if not self._got_error: # pragma: no cover self.logger.debug1('GSS error from server: %s', str(exc)) @@ -294,7 +295,7 @@ async def _start(self) -> None: await self._conn.host_based_auth_requested() if keypair is None: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying host based auth of user %s on host %s ' @@ -322,7 +323,7 @@ async def _start(self) -> None: self._keypair = await self._conn.public_key_auth_requested() if self._keypair is None: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying public key auth with %s key', @@ -340,10 +341,14 @@ async def _send_signed_request(self) -> None: self.logger.debug1('Signing request with %s key', self._keypair.algorithm) - await self.send_request(Boolean(True), - String(self._keypair.algorithm), - String(self._keypair.public_data), - key=self._keypair, trivial=False) + try: + await self.send_request(Boolean(True), + String(self._keypair.algorithm), + String(self._keypair.public_data), + key=self._keypair, trivial=False) + except ValueError as exc: + self.logger.debug1('Public key auth failed: %s', str(exc)) + self._conn.try_next_auth() def _process_public_key_ok(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -377,7 +382,7 @@ async def _start(self) -> None: submethods = await self._conn.kbdint_auth_requested() if submethods is None: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying keyboard-interactive auth') @@ -393,7 +398,7 @@ async def _receive_challenge(self, name: str, instruction: str, lang: str, lang, prompts) if responses is None: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.send_packet(MSG_USERAUTH_INFO_RESPONSE, UInt32(len(responses)), @@ -454,7 +459,7 @@ async def _start(self) -> None: password = await self._conn.password_auth_requested() if password is None: - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying password auth') @@ -469,7 +474,7 @@ async def _change_password(self, prompt: str, lang: str) -> None: if result == NotImplemented: # Password change not supported - move on to the next auth method - self._conn.try_next_auth() + self._conn.try_next_auth(next_method=True) return self.logger.debug1('Trying to chsnge password') @@ -543,10 +548,10 @@ def send_failure(self, partial_success: bool = False) -> None: self._conn.send_userauth_failure(partial_success) - def send_success(self) -> None: + async def send_success(self) -> None: """Send a user authentication success response""" - self._conn.send_userauth_success() + await self._conn.send_userauth_success() class _ServerNullAuth(ServerAuth): @@ -591,7 +596,7 @@ async def _start(self, packet: SSHPacket) -> None: (await self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host))): - self.send_success() + await self.send_success() else: self.send_failure() @@ -635,6 +640,7 @@ async def _start(self, packet: SSHPacket) -> None: return self.logger.debug1('Trying GSS MIC auth') + self._gss.reset() self.send_packet(MSG_USERAUTH_GSSAPI_RESPONSE, String(match)) @@ -644,19 +650,19 @@ async def _finish(self) -> None: if (await self._conn.validate_gss_principal(self._username, self._gss.user, self._gss.host)): - self.send_success() + await self.send_success() else: self.send_failure() - def _process_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS token from the client""" token: Optional[bytes] = packet.get_string() packet.check_end() try: - token = self._gss.step(token) + token = await run_in_executor(self._gss.step, token) if token: self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token)) @@ -681,15 +687,15 @@ def _process_exchange_complete(self, _pkttype: int, _pktid: int, else: self.send_failure() - def _process_error_token(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_error_token(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS error token from the client""" token = packet.get_string() packet.check_end() try: - self._gss.step(token) + await run_in_executor(self._gss.step, token) except GSSError as exc: self.logger.debug1('GSS error from client: %s', str(exc)) @@ -751,7 +757,7 @@ async def _start(self, packet: SSHPacket) -> None: key_data, client_host, client_username, msg, signature)): - self.send_success() + await self.send_success() else: self.send_failure() @@ -789,7 +795,7 @@ async def _start(self, packet: SSHPacket) -> None: if (await self._conn.validate_public_key(self._username, key_data, msg, signature)): if sig_present: - self.send_success() + await self.send_success() else: self.send_packet(MSG_USERAUTH_PK_OK, String(algorithm), String(key_data)) @@ -826,9 +832,9 @@ async def _start(self, packet: SSHPacket) -> None: challenge = await self._conn.get_kbdint_challenge(self._username, lang, submethods) - self._send_challenge(challenge) + await self._send_challenge(challenge) - def _send_challenge(self, challenge: KbdIntChallenge) -> None: + async def _send_challenge(self, challenge: KbdIntChallenge) -> None: """Send a keyboard interactive authentication request""" if isinstance(challenge, (tuple, list)): @@ -842,7 +848,7 @@ def _send_challenge(self, challenge: KbdIntChallenge) -> None: String(instruction), String(lang), UInt32(num_prompts), *prompts_bytes) elif challenge: - self.send_success() + await self.send_success() else: self.send_failure() @@ -851,7 +857,7 @@ async def _validate_response(self, responses: KbdIntResponse) -> None: next_challenge = \ await self._conn.validate_kbdint_response(self._username, responses) - self._send_challenge(next_challenge) + await self._send_challenge(next_challenge) def _process_info_response(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -916,7 +922,7 @@ async def _start(self, packet: SSHPacket) -> None: await self._conn.validate_password(self._username, password) if result: - self.send_success() + await self.send_success() else: self.send_failure() except PasswordChangeRequired as exc: diff --git a/asyncssh/auth_keys.py b/asyncssh/auth_keys.py index 4a1c517..759fe0e 100644 --- a/asyncssh/auth_keys.py +++ b/asyncssh/auth_keys.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -120,8 +120,8 @@ def _add_permitopen(self, option: str, value: str) -> None: host = host[1:-1] port = None if port_str == '*' else int(port_str) - except: - raise ValueError('Illegal permitopen value: %s' % value) from None + except ValueError: + raise ValueError(f'Illegal permitopen value: {value}') from None permitted_opens = cast(Set[Tuple[str, Optional[int]]], self.options.setdefault(option, set())) diff --git a/asyncssh/channel.py b/asyncssh/channel.py index a957138..b0e3d0c 100644 --- a/asyncssh/channel.py +++ b/asyncssh/channel.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -26,6 +26,7 @@ import inspect import re import signal as _signal +import sys from types import MappingProxyType from typing import TYPE_CHECKING, Any, AnyStr, Awaitable, Callable from typing import Dict, Generic, Iterable, List, Mapping, Optional @@ -45,19 +46,23 @@ from .logging import SSHLogger -from .misc import ChannelOpenError, MaybeAwait, ProtocolError -from .misc import get_symbol_names, map_handler_name +from .misc import ChannelOpenError, EnvMap, MaybeAwait, ProtocolError +from .misc import TermModes, TermSize, TermSizeArg +from .misc import decode_env, encode_env, get_symbol_names, map_handler_name from .packet import Boolean, Byte, String, UInt32, SSHPacket, SSHPacketHandler -from .session import TermModes, TermSize, TermSizeArg from .session import SSHSession, SSHClientSession, SSHServerSession -from .session import SSHTCPSession, SSHUNIXSession +from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHSessionFactory, SSHClientSessionFactory from .session import SSHTCPSessionFactory, SSHUNIXSessionFactory +from .session import SSHTunTapSessionFactory from .stream import DataType +from .tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_UNIT_ANY +from .tuntap import SSH_TUN_AF_INET, SSH_TUN_AF_INET6 + if TYPE_CHECKING: # pylint: disable=cyclic-import @@ -75,7 +80,7 @@ _signal_names = {v: k for (k, v) in _signal_numbers.items()} _ExitSignal = Tuple[str, bool, str, str] -_RequestHandler = Callable[[SSHPacket], Optional[bool]] +_RequestHandler = Optional[Callable[[SSHPacket], Optional[bool]]] class SSHChannel(Generic[AnyStr], SSHPacketHandler): @@ -111,7 +116,9 @@ def __init__(self, conn: 'SSHConnection', self._send_high_water: int self._send_low_water: int - self._env: Dict[str, str] = {} + self._env: Dict[bytes, bytes] = {} + self._str_env: Optional[Dict[str, str]] = None + self._command: Optional[str] = None self._subsystem: Optional[str] = None @@ -139,8 +146,7 @@ def __init__(self, conn: 'SSHConnection', self._recv_chan: Optional[int] = conn.add_channel(self) - self._logger = conn.logger.get_child(context='chan=%d' % - self._recv_chan) + self._logger = conn.logger.get_child(context=f'chan={self._recv_chan}') self.set_encoding(encoding, errors) self.set_write_buffer_limits() @@ -198,7 +204,7 @@ def get_write_datatypes(self) -> Set[int]: return self._write_datatypes - def _cleanup(self, exc: Exception = None) -> None: + def _cleanup(self, exc: Optional[Exception] = None) -> None: """Clean up this channel""" if self._open_waiter: @@ -220,7 +226,13 @@ def _cleanup(self, exc: Exception = None) -> None: self._request_waiters = [] if self._session is not None: - self._session.connection_lost(exc) + # pylint: disable=broad-except + try: + self._session.connection_lost(exc) + except Exception: + self.logger.debug1('Uncaught exception in session ignored', + exc_info=sys.exc_info) + self._session = None self._close_event.set() @@ -322,7 +334,7 @@ def _flush_send_buf(self) -> None: elif self._send_state == 'close_pending': self._close_send() - def _flush_recv_buf(self, exc: Exception = None) -> None: + def _flush_recv_buf(self, exc: Optional[Exception] = None) -> None: """Flush as much data in the recv buffer as the application allows""" while self._recv_buf and not self._recv_paused: @@ -394,19 +406,6 @@ def _accept_data(self, data: bytes, datatype: DataType = None) -> None: if self._send_state in {'close_pending', 'closed'}: return - datalen = len(data) - - if datalen > self._recv_window: - raise ProtocolError('Window exceeded') - - if datatype: - typename = ' from %s' % _data_type_names[datatype] - else: - typename = '' - - self.logger.debug2('Received %d data byte%s%s', datalen, - 's' if datalen > 1 else '', typename) - if self._recv_paused: self._recv_buf.append((data, datatype)) else: @@ -453,6 +452,8 @@ def process_connection_close(self, exc: Optional[Exception]) -> None: self.logger.info('Closing channel due to connection close') + self._send_state = 'closed' + self._close_send() self._cleanup(exc) def process_open(self, send_chan: int, send_window: int, send_pktsize: int, @@ -577,6 +578,14 @@ def _process_data(self, _pkttype: int, _pktid: int, data = packet.get_string() packet.check_end() + datalen = len(data) + + if datalen > self._recv_window: + raise ProtocolError('Window exceeded') + + self.logger.debug2('Received %d data byte%s', datalen, + 's' if datalen > 1 else '') + self._accept_data(data) def _process_extended_data(self, _pkttype: int, _pktid: int, @@ -593,6 +602,15 @@ def _process_extended_data(self, _pkttype: int, _pktid: int, if datatype not in self._read_datatypes: raise ProtocolError('Invalid extended data type') + datalen = len(data) + + if datalen > self._recv_window: + raise ProtocolError('Window exceeded') + + self.logger.debug2('Received %d data byte%s from %s', datalen, + 's' if datalen > 1 else '', + _data_type_names[datatype]) + self._accept_data(data, datatype) def _process_eof(self, _pkttype: int, _pktid: int, @@ -845,8 +863,8 @@ def get_write_buffer_size(self) -> int: return self._send_buf_len - def set_write_buffer_limits(self, high: int = None, - low: int = None) -> None: + def set_write_buffer_limits(self, high: Optional[int] = None, + low: Optional[int] = None) -> None: """Set the high- and low-water limits for write flow control This method sets the limits used when deciding when to call @@ -865,8 +883,8 @@ def set_write_buffer_limits(self, high: int = None, low = high // 4 if not 0 <= low <= high: - raise ValueError('high (%r) must be >= low (%r) must be >= 0' % - (high, low)) + raise ValueError(f'high (high) must be >= low ({low}) ' + 'must be >= 0') self.logger.debug1('Set write buffer limits: low-water=%d, ' 'high-water=%d', low, high) @@ -920,7 +938,7 @@ def write(self, data: AnyStr, datatype: DataType = None) -> None: datalen = len(encoded_data) if datatype: - typename = ' to %s' % _data_type_names[datatype] + typename = f' to {_data_type_names[datatype]}' else: typename = '' @@ -1005,7 +1023,7 @@ def resume_reading(self) -> None: This method can be called to resume delivery of incoming data which was suspended by a call to :meth:`pause_reading`. As soon as this method is called, any buffered data will be delivered - immediately. A pending end-of-file notication may also be + immediately. A pending end-of-file notification may also be delivered if one was queued while reading was paused. """ @@ -1020,12 +1038,41 @@ def get_environment(self) -> Mapping[str, str]: """Return the environment for this session This method returns the environment set by the client when - the session was opened. On the server, calls to this method - should only be made after :meth:`session_started - ` has been called on the - :class:`SSHServerSession`. When using the stream-based API, - calls to this can be made at any time after the handler - function has started up. + the session was opened. Keys and values are of type `str` + and this object only provides access to keys and values sent + as valid UTF-8 strings. Use :meth:`get_environment_bytes` + if you need to access environment variables with keys or + values containing binary data or non-UTF-8 encodings. + + On the server, calls to this method should only be made after + :meth:`session_started ` has + been called on the :class:`SSHServerSession`. When using the + stream-based API, calls to this can be made at any time after + the handler function has started up. + + :returns: A dictionary containing the environment variables + set by the client + + """ + + if self._str_env is None: + self._str_env = dict(decode_env(self._env)) + + return MappingProxyType(self._str_env) + + def get_environment_bytes(self) -> Mapping[bytes, bytes]: + """Return the environment for this session + + This method returns the environment set by the client when + the session was opened. Keys and values are of type `bytes` + and can include arbitrary binary data, with the exception + of NUL (\0) bytes. + + On the server, calls to this method should only be made after + :meth:`session_started ` has + been called on the :class:`SSHServerSession`. When using the + stream-based API, calls to this can be made at any time after + the handler function has started up. :returns: A dictionary containing the environment variables set by the client @@ -1040,7 +1087,7 @@ def get_command(self) -> Optional[str]: This method returns the command the client requested to execute when the session was opened, if any. If the client did not request that a command be executed, this method - will return `None`. On the server, alls to this method + will return `None`. On the server, calls to this method should only be made after :meth:`session_started ` has been called on the :class:`SSHServerSession`. When using the stream-based API, @@ -1087,7 +1134,7 @@ def __init__(self, conn: 'SSHClientConnection', async def create(self, session_factory: SSHClientSessionFactory[AnyStr], command: Optional[str], subsystem: Optional[str], - env: Dict[str, str], request_pty: bool, + env: Dict[bytes, bytes], request_pty: bool, term_type: Optional[str], term_size: TermSizeArg, term_modes: TermModes, x11_forwarding: Union[bool, str], x11_display: Optional[str], x11_auth_path: Optional[str], @@ -1109,10 +1156,16 @@ async def create(self, session_factory: SSHClientSessionFactory[AnyStr], self._command = command self._subsystem = subsystem - for name, env_value in env.items(): - self.logger.debug1(' Env: %s=%s', name, env_value) - self._send_request(b'env', String(str(name)), - String(str(env_value))) + for key, value in env.items(): + self.logger.debug1(' Env: %s=%s', key, value) + + if not isinstance(key, (bytes, str)): + key = str(key) + + if not isinstance(value, (bytes, str)): + value = str(value) + + self._send_request(b'env', String(key), String(value)) if request_pty: self.logger.debug1(' Terminal type: %s', term_type or 'None') @@ -1134,7 +1187,7 @@ async def create(self, session_factory: SSHClientSessionFactory[AnyStr], modes = b'' for mode, mode_value in term_modes.items(): if mode <= PTY_OP_END or mode >= PTY_OP_RESERVED: - raise ValueError('Invalid pty mode: %s' % mode) + raise ValueError(f'Invalid pty mode: {mode}') name = _pty_mode_names.get(mode, str(mode)) self.logger.debug2(' Mode %s: %d', name, mode_value) @@ -1392,7 +1445,7 @@ def send_signal(self, signal: Union[str, int]) -> None: try: signal = _signal_names[signal] except KeyError: - raise ValueError('Unknown signal: %s' % int(signal)) from None + raise ValueError(f'Unknown signal: {signal}') from None self.logger.info('Sending %s signal', signal) @@ -1450,8 +1503,8 @@ def __init__(self, conn: 'SSHServerConnection', super().__init__(conn, loop, encoding, errors, window, max_pktsize) - self._env = cast(Dict[str, str], - conn.get_key_option('environment', {})) + env_opt = cast(EnvMap, conn.get_key_option('environment', {})) + self._env = dict(encode_env(env_opt)) self._allow_pty = allow_pty self._line_editor = line_editor @@ -1606,23 +1659,16 @@ async def _finish_agent_req_request(self) -> None: def _process_env_request(self, packet: SSHPacket) -> bool: """Process a request to set an environment variable""" - name_bytes = packet.get_string() - value_bytes = packet.get_string() + key = packet.get_string() + value = packet.get_string() packet.check_end() - try: - name = name_bytes.decode('utf-8') - value = value_bytes.decode('utf-8') - except UnicodeDecodeError: - self.logger.debug1('Invalid environment data') - return False - - self.logger.debug1(' Env: %s=%s', name, value) - self._env[name] = value + self.logger.debug1(' Env: %s=%s', key, value) + self._env[key] = value return True - def _start_session(self, command: str = None, - subsystem: str = None) -> bool: + def _start_session(self, command: Optional[str] = None, + subsystem: Optional[str] = None) -> bool: """Tell the session what type of channel is being requested""" forced_command = \ @@ -1696,7 +1742,7 @@ def _process_window_change_request(self, packet: SSHPacket) -> bool: self.logger.info('Received window change: %sx%s (%sx%s pixels)', width, height, pixwidth, pixheight) else: - self.logger.info('Recceived window change: %sx%s', width, height) + self.logger.info('Received window change: %sx%s', width, height) self._term_size = (width, height, pixwidth, pixheight) self._session.terminal_size_changed(width, height, pixwidth, pixheight) @@ -1827,7 +1873,7 @@ def get_x11_display(self) -> Optional[str]: forwarding, this method returns `None`. :returns: A `str` containing the X11 display or `None` if - X11 fowarding was not requested + X11 forwarding was not requested """ @@ -1843,7 +1889,7 @@ def get_agent_path(self) -> Optional[str]: `None`. :returns: A `str` containing the ssh-agent socket path or - `None` if agent fowarding was not requested + `None` if agent forwarding was not requested """ @@ -2019,16 +2065,16 @@ async def connect(self, session_factory: SSHTCPSessionFactory[AnyStr], SSHTCPSession[AnyStr]: """Create a new outbound TCP session""" - return (await self._open_tcp(session_factory, b'direct-tcpip', - host, port, orig_host, orig_port)) + return await self._open_tcp(session_factory, b'direct-tcpip', + host, port, orig_host, orig_port) async def accept(self, session_factory: SSHTCPSessionFactory[AnyStr], host: str, port: int, orig_host: str, orig_port: int) -> SSHTCPSession[AnyStr]: """Create a new forwarded TCP session""" - return (await self._open_tcp(session_factory, b'forwarded-tcpip', - host, port, orig_host, orig_port)) + return await self._open_tcp(session_factory, b'forwarded-tcpip', + host, port, orig_host, orig_port) def set_inbound_peer_names(self, dest_host: str, dest_port: int, orig_host: str, orig_port: int) -> None: @@ -2078,6 +2124,54 @@ def set_inbound_peer_names(self, dest_path: str) -> None: self.set_extra_info(local_peername=dest_path, remote_peername='') +class SSHTunTapChannel(SSHForwardChannel[bytes]): + """SSH TunTap channel""" + + def __init__(self, conn: 'SSHConnection', + loop: asyncio.AbstractEventLoop, encoding: Optional[str], + errors: str, window: int, max_pktsize: int): + super().__init__(conn, loop, encoding, errors, window, max_pktsize) + + self._mode: Optional[int] = None + + def _accept_data(self, data: bytes, datatype: DataType = None) -> None: + """Strip off address family on incoming packets in TUN mode""" + + if self._mode == SSH_TUN_MODE_POINTTOPOINT: + data = data[4:] + + super()._accept_data(data, datatype) + + def write(self, data: bytes, datatype: DataType = None) -> None: + """Add address family in outbound packets in TUN mode""" + + if self._mode == SSH_TUN_MODE_POINTTOPOINT: + version = data[0] >> 4 + family = SSH_TUN_AF_INET if version == 4 else SSH_TUN_AF_INET6 + data = UInt32(family) + data + + super().write(data, datatype) + + async def open(self, session_factory: SSHTunTapSessionFactory, + mode: int, unit: Optional[int]) -> SSHTunTapSession: + """Open a TUN/TAP channel""" + + self._mode = mode + + if unit is None: + unit = SSH_TUN_UNIT_ANY + + return cast(SSHTunTapSession, + await self._open_forward(session_factory, + b'tun@openssh.com', + UInt32(mode), UInt32(unit))) + + def set_mode(self, mode: int) -> None: + """Set mode for inbound connections""" + + self._mode = mode + + class SSHX11Channel(SSHForwardChannel[bytes]): """SSH X11 channel""" diff --git a/asyncssh/client.py b/asyncssh/client.py index 3150283..af61c18 100644 --- a/asyncssh/client.py +++ b/asyncssh/client.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -39,6 +39,14 @@ class SSHClient: to receive callbacks when certain events occur on the SSH connection. + Whenever a new SSH client connection is opened, a corresponding + SSHClient object is created and the method :meth:`connection_made` + is called, passing in the :class:`SSHClientConnection` object. + + When the connection is closed, the method :meth:`connection_lost` + is called with an exception representing the reason for the + disconnect, or `None` if the connection was closed cleanly. + For simple password or public key based authentication, nothing needs to be defined here if the password or client keys are passed in when the connection is created. However, to prompt interactively @@ -212,11 +220,20 @@ def auth_banner_received(self, msg: str, lang: str) -> None: """ + def begin_auth(self, username: str) -> None: + """Begin client authentication + + This method is called when client authentication is about to + begin, Applications may store the username passed here to + be used in future authentication callbacks. + + """ + def auth_completed(self) -> None: """Authentication was completed successfully This method is called when authentication has completed - succesfully. Applications may use this method to create + successfully. Applications may use this method to create whatever client sessions and direct TCP/IP or UNIX domain connections are needed and/or set up listeners for incoming TCP/IP or UNIX domain connections coming from the server. diff --git a/asyncssh/compression.py b/asyncssh/compression.py index f014760..e05ca74 100644 --- a/asyncssh/compression.py +++ b/asyncssh/compression.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -149,9 +149,9 @@ def get_decompressor(alg: bytes) -> Optional[Decompressor]: return _cmp_decompressors[alg]() +register_compression_alg(b'none', + _none, _none, False, True) register_compression_alg(b'zlib@openssh.com', _ZLibCompress, _ZLibDecompress, True, True) register_compression_alg(b'zlib', _ZLibCompress, _ZLibDecompress, False, False) -register_compression_alg(b'none', - _none, _none, False, True) diff --git a/asyncssh/config.py b/asyncssh/config.py index c460817..0376b74 100644 --- a/asyncssh/config.py +++ b/asyncssh/config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022 by Ron Frederick and others. +# Copyright (c) 2020-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -41,6 +41,10 @@ ConfigPaths = Union[None, FilePath, Sequence[FilePath]] +_token_pattern = re.compile(r'%(.)') +_env_pattern = re.compile(r'\${(.*)}') + + def _exec(cmd: str) -> bool: """Execute a command and return if exit status is 0""" @@ -60,12 +64,15 @@ class SSHConfig: _percent_expand = {'AuthorizedKeysFile'} _handlers: Dict[str, Tuple[str, Callable]] = {} - def __init__(self, last_config: Optional['SSHConfig'], reload: bool): + def __init__(self, last_config: Optional['SSHConfig'], reload: bool, + canonical: bool, final: bool): if last_config: self._last_options = last_config.get_options(reload) else: self._last_options = {} + self._canonical = canonical + self._final = True if final else None self._default_path = Path('~', '.ssh').expanduser() self._path = Path() self._line_no = 0 @@ -75,11 +82,10 @@ def __init__(self, last_config: Optional['SSHConfig'], reload: bool): self.loaded = False - def _error(self, reason: str, *args: object) -> NoReturn: + def _error(self, reason: str) -> NoReturn: """Raise a configuration parsing error""" - raise ConfigParseError('%s line %s: %s' % (self._path, self._line_no, - reason % args)) + raise ConfigParseError(f'{self._path} line {self._line_no}: {reason}') def _match_val(self, match: str) -> object: """Return the value to match against in a match condition""" @@ -91,36 +97,38 @@ def _set_tokens(self) -> None: raise NotImplementedError - def _expand_val(self, value: str) -> str: - """Perform percent token expansion on a string""" + def _expand_token(self, match): + """Expand a percent token reference""" - last_idx = 0 - result: List[str] = [] + try: + token = match.group(1) + return self._tokens[token] + except KeyError: + if token == 'd': + raise ConfigParseError('Home directory is ' + 'not available') from None + elif token == 'i': + raise ConfigParseError('User id not available') from None + else: + raise ConfigParseError('Invalid token expansion: ' + + token) from None - for match in re.finditer(r'%', value): - idx = match.start() + @staticmethod + def _expand_env(match): + """Expand an environment variable reference""" - if idx < last_idx: - continue + try: + var = match.group(1) + return os.environ[var] + except KeyError: + raise ConfigParseError('Invalid environment expansion: ' + + var) from None - try: - token = value[idx+1] - result.extend([value[last_idx:idx], self._tokens[token]]) - last_idx = idx + 2 - except IndexError: - raise ConfigParseError('Invalid token substitution') from None - except KeyError: - if token == 'd': - raise ConfigParseError('Home directory is ' - 'not available') from None - elif token == 'i': - raise ConfigParseError('User id not available') from None - else: - raise ConfigParseError('Invalid token substitution: %s' % - value[idx+1]) from None - - result.append(value[last_idx:]) - return ''.join(result) + def _expand_val(self, value: str) -> str: + """Perform percent token and environment expansion on a string""" + + return _env_pattern.sub(self._expand_env, + _token_pattern.sub(self._expand_token, value)) def _include(self, option: str, args: List[str]) -> None: """Read config from a list of other config files""" @@ -141,7 +149,7 @@ def _include(self, option: str, args: List[str]) -> None: paths = list(path.glob(pattern)) if not paths: - logger.debug1('Config pattern "%s" matched no files', pattern) + logger.debug1(f'Config pattern "{pattern}" matched no files') for path in paths: self.parse(path) @@ -154,35 +162,53 @@ def _match(self, option: str, args: List[str]) -> None: # pylint: disable=unused-argument + matching = True + while args: match = args.pop(0).lower() - if match == 'all': - self._matching = True - continue + if match[0] == '!': + match = match[1:] + negated = True + else: + negated = False - match_val = self._match_val(match) + if match == 'final' and self._final is None: + self._final = False - if match != 'exec' and match_val is None: - self._error('Invalid match condition') + if match == 'all': + result = True + elif match == 'canonical': + result = self._canonical + elif match == 'final': + result = cast(bool, self._final) + else: + match_val = self._match_val(match) - try: - if match == 'exec': - self._matching = _exec(args.pop(0)) - elif match in ('address', 'localaddress'): - host_pat = HostPatternList(args.pop(0)) - ip = ip_address(cast(str, match_val)) \ - if match_val else None - self._matching = host_pat.matches(None, match_val, ip) - else: - wild_pat = WildcardPatternList(args.pop(0)) - self._matching = wild_pat.matches(match_val) - except IndexError: - self._error('Missing %s match pattern', match) - - if not self._matching: - args.clear() - break + if match != 'exec' and match_val is None: + self._error(f'Invalid match condition {match}') + + try: + arg = args.pop(0) + except IndexError: + self._error(f'Missing {match} match pattern') + + if matching: + if match == 'exec': + result = _exec(arg) + elif match in ('address', 'localaddress'): + host_pat = HostPatternList(arg) + ip = ip_address(cast(str, match_val)) \ + if match_val else None + result = host_pat.matches(None, match_val, ip) + else: + wild_pat = WildcardPatternList(arg) + result = wild_pat.matches(match_val) + + if matching and result == negated: + matching = False + + self._matching = matching def _set_bool(self, option: str, args: List[str]) -> None: """Set a boolean config option""" @@ -194,7 +220,23 @@ def _set_bool(self, option: str, args: List[str]) -> None: elif value_str in ('no', 'false'): value = False else: - self._error('Invalid %s boolean value: %s', option, value_str) + self._error(f'Invalid {option} boolean value: {value_str}') + + if option not in self._options: + self._options[option] = value + + def _set_bool_or_str(self, option: str, args: List[str]) -> None: + """Set a boolean or string config option""" + + value_str = args.pop(0) + value_lower = value_str.lower() + + if value_lower in ('yes', 'true'): + value: Union[bool, str] = True + elif value_lower in ('no', 'false'): + value = False + else: + value = value_str if option not in self._options: self._options[option] = value @@ -207,7 +249,7 @@ def _set_int(self, option: str, args: List[str]) -> None: try: value = int(value_str) except ValueError: - self._error('Invalid %s integer value: %s', option, value_str) + self._error(f'Invalid {option} integer value: {value_str}') if option not in self._options: self._options[option] = value @@ -243,7 +285,10 @@ def _set_string_list(self, option: str, args: List[str]) -> None: """Set whitespace-separated string config options as a list""" if option not in self._options: - self._options[option] = args[:] + if len(args) == 1 and args[0].lower() == 'none': + self._options[option] = [] + else: + self._options[option] = args[:] args.clear() @@ -269,7 +314,24 @@ def _set_address_family(self, option: str, args: List[str]) -> None: elif value_str == 'inet6': value = socket.AF_INET6 else: - self._error('Invalid %s value: %s', option, value_str) + self._error(f'Invalid {option} value: {value_str}') + + if option not in self._options: + self._options[option] = value + + def _set_canonicalize_host(self, option: str, args: List[str]) -> None: + """Set a canonicalize host config option""" + + value_str = args.pop(0).lower() + + if value_str in ('yes', 'true'): + value: Union[bool, str] = True + elif value_str in ('no', 'false'): + value = False + elif value_str == 'always': + value = value_str + else: + self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value @@ -293,6 +355,11 @@ def _set_rekey_limits(self, option: str, args: List[str]) -> None: if option not in self._options: self._options[option] = byte_limit, time_limit + def has_match_final(self) -> bool: + """Return whether this config includes a 'Match final' block""" + + return self._final is not None + def parse(self, path: Path) -> None: """Parse an OpenSSH config file and return matching declarations""" @@ -312,23 +379,33 @@ def parse(self, path: Path) -> None: continue try: - args = shlex.split(line) + split_args = shlex.split(line) except ValueError as exc: self._error(str(exc)) - option = args.pop(0) - - if option.endswith('='): - option = option[:-1] - elif '=' in option: - option, arg = option.split('=', 1) - args[:0] =[arg] - elif args and args[0] == '=': - del args[0] - elif args and args[0].startswith('='): - args[0] = args[0][1:] - - loption = option.lower() + args = [] + loption = '' + allow_equal = True + + for i, arg in enumerate(split_args, 1): + if arg.startswith('='): + if len(arg) > 1: + args.append(arg[1:]) + elif not allow_equal: + args.extend(split_args[i-1:]) + break + elif arg.endswith('='): + args.append(arg[:-1]) + elif '=' in arg: + arg, val = arg.split('=', 1) + args.append(arg) + args.append(val) + else: + args.append(arg) + + if i == 1: + loption = args.pop(0).lower() + allow_equal = loption in self._conditionals if loption in self._no_split: args = [line.lstrip()[len(loption):].strip()] @@ -342,12 +419,12 @@ def parse(self, path: Path) -> None: continue if not args: - self._error('Missing %s value', option) + self._error(f'Missing {option} value') handler(self, option, args) if args: - self._error('Extra data at end: %s', ' '.join(args)) + self._error(f'Extra data at end: {" ".join(args)}') self._set_tokens() @@ -372,10 +449,10 @@ def get_options(self, reload: bool) -> Dict[str, object]: @classmethod def load(cls, last_config: Optional['SSHConfig'], config_paths: ConfigPaths, reload: bool, - *args: object) -> 'SSHConfig': + canonical: bool, final: bool, *args: object) -> 'SSHConfig': """Load a list of OpenSSH config files into a config object""" - config = cls(last_config, reload, *args) + config = cls(last_config, reload, canonical, final, *args) if config_paths: if isinstance(config_paths, (str, PurePath)): @@ -412,13 +489,14 @@ class SSHClientConfig(SSHConfig): """Settings from an OpenSSH client config file""" _conditionals = {'host', 'match'} - _no_split = {'remotecommand'} - _percent_expand = {'CertificateFile', 'IdentityAgent', + _no_split = {'proxycommand', 'remotecommand'} + _percent_expand = {'CertificateFile', 'ForwardAgent', 'IdentityAgent', 'IdentityFile', 'ProxyCommand', 'RemoteCommand'} def __init__(self, last_config: 'SSHConfig', reload: bool, - local_user: str, user: str, host: str, port: int) -> None: - super().__init__(last_config, reload) + canonical: bool, final: bool, local_user: str, + user: str, host: str, port: int) -> None: + super().__init__(last_config, reload, canonical, final) self._local_user = local_user self._orig_host = host @@ -440,6 +518,8 @@ def _match_val(self, match: str) -> object: return self._local_user elif match == 'user': return self._options.get('User', self._local_user) + elif match == 'tagged': + return self._options.get('Tag', '') else: return None @@ -471,10 +551,10 @@ def _set_request_tty(self, option: str, args: List[str]) -> None: value: Union[bool, str] = True elif value_str in ('no', 'false'): value = False - elif value_str not in ('force', 'auto'): - self._error('Invalid %s value: %s', option, value_str) - else: + elif value_str in ('force', 'auto'): value = value_str + else: + self._error(f'Invalid {option} value: {value_str}') if option not in self._options: self._options[option] = value @@ -517,6 +597,11 @@ def _set_tokens(self) -> None: ('AddressFamily', SSHConfig._set_address_family), ('BindAddress', SSHConfig._set_string), + ('CanonicalDomains', SSHConfig._set_string_list), + ('CanonicalizeFallbackLocal', SSHConfig._set_bool), + ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), + ('CanonicalizeMaxDots', SSHConfig._set_int), + ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('CertificateFile', SSHConfig._append_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), @@ -524,7 +609,7 @@ def _set_tokens(self) -> None: ('Compression', SSHConfig._set_bool), ('ConnectTimeout', SSHConfig._set_int), ('EnableSSHKeySign', SSHConfig._set_bool), - ('ForwardAgent', SSHConfig._set_bool), + ('ForwardAgent', SSHConfig._set_bool_or_str), ('ForwardX11Trusted', SSHConfig._set_bool), ('GlobalKnownHostsFile', SSHConfig._set_string_list), ('GSSAPIAuthentication', SSHConfig._set_bool), @@ -544,7 +629,7 @@ def _set_tokens(self) -> None: ('PKCS11Provider', SSHConfig._set_string), ('PreferredAuthentications', SSHConfig._set_string), ('Port', SSHConfig._set_int), - ('ProxyCommand', SSHConfig._set_string_list), + ('ProxyCommand', SSHConfig._set_string), ('ProxyJump', SSHConfig._set_string), ('PubkeyAuthentication', SSHConfig._set_bool), ('RekeyLimit', SSHConfig._set_rekey_limits), @@ -553,7 +638,8 @@ def _set_tokens(self) -> None: ('SendEnv', SSHConfig._append_string_list), ('ServerAliveCountMax', SSHConfig._set_int), ('ServerAliveInterval', SSHConfig._set_int), - ('SetEnv', SSHConfig._append_string_list), + ('SetEnv', SSHConfig._set_string_list), + ('Tag', SSHConfig._set_string), ('TCPKeepAlive', SSHConfig._set_bool), ('User', SSHConfig._set_string), ('UserKnownHostsFile', SSHConfig._set_string_list) @@ -564,9 +650,9 @@ class SSHServerConfig(SSHConfig): """Settings from an OpenSSH server config file""" def __init__(self, last_config: 'SSHConfig', reload: bool, - local_addr: str, local_port: int, user: str, - host: str, addr: str) -> None: - super().__init__(last_config, reload) + canonical: bool, final: bool, local_addr: str, + local_port: int, user: str, host: str, addr: str) -> None: + super().__init__(last_config, reload, canonical, final) self._local_addr = local_addr self._local_port = local_port @@ -603,6 +689,11 @@ def _set_tokens(self) -> None: ('AuthorizedKeysFile', SSHConfig._set_string_list), ('AllowAgentForwarding', SSHConfig._set_bool), ('BindAddress', SSHConfig._set_string), + ('CanonicalDomains', SSHConfig._set_string_list), + ('CanonicalizeFallbackLocal', SSHConfig._set_bool), + ('CanonicalizeHostname', SSHConfig._set_canonicalize_host), + ('CanonicalizeMaxDots', SSHConfig._set_int), + ('CanonicalizePermittedCNAMEs', SSHConfig._set_string_list), ('CASignatureAlgorithms', SSHConfig._set_string), ('ChallengeResponseAuthentication', SSHConfig._set_bool), ('Ciphers', SSHConfig._set_string), diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 6246811..482b7e8 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -25,6 +25,7 @@ import getpass import inspect import io +import ipaddress import os import shlex import socket @@ -37,9 +38,9 @@ from pathlib import Path from types import TracebackType from typing import TYPE_CHECKING, Any, AnyStr, Awaitable, Callable, Dict -from typing import List, Mapping, Optional, Sequence, Set, Tuple, Type -from typing import TypeVar, Union, cast -from typing_extensions import Protocol +from typing import Generic, List, Mapping, Optional, Sequence, Set, Tuple +from typing import Type, TypeVar, Union, cast +from typing_extensions import Protocol, Self from .agent import SSHAgentClient, SSHAgentListener @@ -51,7 +52,7 @@ from .auth_keys import SSHAuthorizedKeys, read_authorized_keys from .channel import SSHChannel, SSHClientChannel, SSHServerChannel -from .channel import SSHTCPChannel, SSHUNIXChannel +from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel from .channel import SSHX11Channel, SSHAgentChannel from .client import SSHClient @@ -105,21 +106,22 @@ from .mac import get_mac_algs, get_default_mac_algs -from .misc import BytesOrStr, DefTuple, FilePath, HostPort, IPNetwork -from .misc import MaybeAwait, OptExcInfo, Options, SockAddr +from .misc import BytesOrStr, BytesOrStrDict, DefTuple, Env, EnvSeq, FilePath +from .misc import HostPort, IPNetwork, MaybeAwait, OptExcInfo, Options, SockAddr from .misc import ChannelListenError, ChannelOpenError, CompressionError from .misc import DisconnectError, ConnectionLost, HostKeyNotVerifiable from .misc import KeyExchangeFailed, IllegalUserName, MACError from .misc import PasswordChangeRequired, PermissionDenied, ProtocolError from .misc import ProtocolNotSupported, ServiceNotAvailable -from .misc import async_context_manager, construct_disc_error -from .misc import get_symbol_names, ip_address, map_handler_name -from .misc import parse_byte_count, parse_time_interval +from .misc import TermModesArg, TermSizeArg +from .misc import async_context_manager, construct_disc_error, encode_env +from .misc import get_symbol_names, ip_address, lookup_env, map_handler_name +from .misc import parse_byte_count, parse_time_interval, split_args from .packet import Boolean, Byte, NameList, String, UInt32, PacketDecodeError from .packet import SSHPacket, SSHPacketHandler, SSHPacketLogger -from .pattern import WildcardPattern +from .pattern import WildcardPattern, WildcardPatternList from .pkcs11 import load_pkcs11_keys @@ -146,11 +148,10 @@ from .server import SSHServer -from .session import DataType, TermModesArg, TermSizeArg -from .session import SSHClientSession, SSHServerSession -from .session import SSHTCPSession, SSHUNIXSession +from .session import DataType, SSHClientSession, SSHServerSession +from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .session import SSHClientSessionFactory, SSHTCPSessionFactory -from .session import SSHUNIXSessionFactory +from .session import SSHUNIXSessionFactory, SSHTunTapSessionFactory from .sftp import MIN_SFTP_VERSION, SFTPClient, SFTPServer from .sftp import start_sftp_client @@ -159,10 +160,14 @@ from .stream import SSHSocketSessionFactory, SSHServerSessionFactory from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SSHTCPStreamSession, SSHUNIXStreamSession +from .stream import SSHTunTapStreamSession from .subprocess import SSHSubprocessTransport, SSHSubprocessProtocol from .subprocess import SubprocessFactory, SSHSubprocessWritePipe +from .tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_MODE_ETHERNET +from .tuntap import SSH_TUN_UNIT_ANY, create_tuntap + from .version import __version__ from .x11 import SSHX11ClientForwarder @@ -178,8 +183,12 @@ _ServerFactory = Callable[[], SSHServer] _ProtocolFactory = Union[_ClientFactory, _ServerFactory] -_Conn = TypeVar('_Conn', 'SSHClientConnection', 'SSHServerConnection') -_ConnectionFactory = Callable[[], _Conn] +_Conn = TypeVar('_Conn', bound='SSHConnection') +_Options = TypeVar('_Options', bound='SSHConnectionOptions') + +_ServerHostKeysHandler = Optional[Callable[[List[SSHKey], List[SSHKey], + List[SSHKey], List[SSHKey]], + MaybeAwait[None]]] class _TunnelProtocol(Protocol): """Base protocol for connections to tunnel SSH over""" @@ -215,15 +224,13 @@ async def create_server(self, session_factory: TCPListenerFactory, _AuthKeysArg = DefTuple[Union[None, str, List[str], SSHAuthorizedKeys]] _ClientHostKey = Union[SSHKeyPair, SSHKeySignKeyPair] _ClientKeysArg = Union[KeyListArg, KeyPairListArg] - -_Env = Optional[Union[Mapping[str, str], Sequence[str]]] -_SendEnv = Optional[Sequence[str]] +_CNAMEArg = DefTuple[Union[Sequence[str], Sequence[Tuple[str, str]]]] _GlobalRequest = Tuple[Optional[_PacketHandler], SSHPacket, bool] _GlobalRequestResult = Tuple[int, SSHPacket] _KeyOrCertOptions = Mapping[str, object] _ListenerArg = Union[bool, SSHListener] -_ProxyCommand = Optional[Sequence[str]] +_ProxyCommand = Optional[Union[str, Sequence[str]]] _RequestPTY = Union[bool, str] _TCPServerHandlerFactory = Callable[[str, int], SSHSocketSessionFactory] @@ -234,6 +241,7 @@ async def create_server(self, session_factory: TCPListenerFactory, _VersionArg = DefTuple[BytesOrStr] +SSHAcceptHandler = Callable[[str, int], MaybeAwait[bool]] # SSH service names _USERAUTH_SERVICE = b'ssh-userauth' @@ -267,9 +275,71 @@ async def create_server(self, session_factory: TCPListenerFactory, _DEFAULT_MAX_LINE_LENGTH = 1024 # 1024 characters +async def _canonicalize_host(loop: asyncio.AbstractEventLoop, + options: 'SSHConnectionOptions') -> Optional[str]: + """Canonicalize a host name""" + + host = options.host + + if not options.canonicalize_hostname or not options.canonical_domains: + logger.info('Host canonicalization disabled') + return None + + if host.count('.') > options.canonicalize_max_dots: + logger.info('Host canonicalization skipped due to max dots') + return None + + try: + ipaddress.ip_address(host) + except ValueError: + pass + else: + logger.info('Hostname canonicalization skipped on IP address') + return None + + logger.debug1('Beginning hostname canonicalization') + + for domain in options.canonical_domains: + logger.debug1(' Checking domain %s', domain) + + canon_host = f'{host}.{domain}' + + try: + addrinfo = await loop.getaddrinfo( + canon_host, 0, flags=socket.AI_CANONNAME) + except socket.gaierror: + continue + + cname = addrinfo[0][3] + + if cname and cname != canon_host: + logger.debug1(' Checking CNAME rules for hostname %s ' + 'with CNAME %s', canon_host, cname) + + for patterns in options.canonicalize_permitted_cnames: + host_pat, cname_pat = map(WildcardPatternList, patterns) + + if host_pat.matches(canon_host) and cname_pat.matches(cname): + logger.info('Hostname canonicalization to CNAME ' + 'applied: %s -> %s', options.host, cname) + return cname + + logger.info('Hostname canonicalization applied: %s -> %s', + options.host, canon_host) + + return canon_host + + if not options.canonicalize_fallback_local: + logger.info('Hostname canonicalization failed (fallback disabled)') + raise OSError(f'Unable to canonicalize hostname "{host}"') + + logger.info('Hostname canonicalization failed, using local resolver') + return None + + async def _open_proxy( loop: asyncio.AbstractEventLoop, command: Sequence[str], - conn_factory: _ConnectionFactory[_Conn]) -> _Conn: + conn_factory: Callable[[], _Conn]) -> _Conn: """Open a tunnel running a proxy command""" class _ProxyCommandTunnel(asyncio.SubprocessProtocol): @@ -283,12 +353,6 @@ def __init__(self) -> None: self._conn = conn_factory() self._close_event = asyncio.Event() - def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: - """Changing the protocol is ignored here""" - - def get_protocol(self) -> asyncio.BaseProtocol: - """Changing the protocol is ignored here""" - def get_extra_info(self, name: str, default: Any = None) -> Any: """Return extra information associated with this tunnel""" @@ -323,12 +387,6 @@ def pipe_connection_lost(self, fd: int, self._conn.connection_lost(exc) - def is_closing(self) -> bool: - """Return whether the transport is closing or not""" - - assert self._transport is not None - return self._transport.is_closing() - def write(self, data: bytes) -> None: """Write data to this tunnel""" @@ -354,36 +412,60 @@ def close(self) -> None: return cast(_Conn, cast(_ProxyCommandTunnel, tunnel).get_conn()) -async def _open_tunnel(tunnel: object, passphrase: Optional[BytesOrStr]) -> \ +async def _open_tunnel(tunnels: object, options: _Options, + config: DefTuple[ConfigPaths]) -> \ Optional['SSHClientConnection']: """Parse and open connection to tunnel over""" username: DefTuple[str] port: DefTuple[int] - if isinstance(tunnel, str): - if '@' in tunnel: - username, host = tunnel.rsplit('@', 1) - else: - username, host = (), tunnel + if isinstance(tunnels, str): + conn: Optional[SSHClientConnection] = None - if ':' in host: - host, port_str = host.rsplit(':', 1) - port = int(port_str) - else: - port = () + for tunnel in tunnels.split(','): + if '@' in tunnel: + username, host = tunnel.rsplit('@', 1) + else: + username, host = (), tunnel + + if ':' in host: + host, port_str = host.rsplit(':', 1) + port = int(port_str) + else: + port = () + + last_conn = conn + conn = await connect(host, port, username=username, + passphrase=options.passphrase, tunnel=conn, + config=config) + conn.set_tunnel(last_conn) + + if options.canonicalize_hostname != 'always': + options.canonicalize_hostname = False - return await connect(host, port, username=username, - passphrase=passphrase) + return conn else: return None -async def _connect(options: 'SSHConnectionOptions', +async def _connect(options: _Options, config: DefTuple[ConfigPaths], loop: asyncio.AbstractEventLoop, flags: int, - conn_factory: _ConnectionFactory[_Conn], msg: str) -> _Conn: + sock: Optional[socket.socket], + conn_factory: Callable[[], _Conn], msg: str) -> _Conn: """Make outbound TCP or SSH tunneled connection""" + options.waiter = loop.create_future() + + canon_host = await _canonicalize_host(loop, options) + + host = canon_host if canon_host else options.host + canonical = bool(canon_host) + final = options.config.has_match_final() + + if canonical or final: + options.update(host=host, reload=True, canonical=canonical, final=final) + host = options.host port = options.port tunnel = options.tunnel @@ -392,13 +474,17 @@ async def _connect(options: 'SSHConnectionOptions', proxy_command = options.proxy_command free_conn = True - options.waiter = loop.create_future() - - new_tunnel = await _open_tunnel(tunnel, options.passphrase) + new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelConnectorProtocol try: - if new_tunnel: + if sock: + logger.info('%s already-connected socket', msg) + + _, session = await loop.create_connection(conn_factory, sock=sock) + + conn = cast(_Conn, session) + elif new_tunnel: new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel) # pylint: disable=broad-except @@ -436,13 +522,11 @@ async def _connect(options: 'SSHConnectionOptions', options.waiter.cancel() raise + conn.set_extra_info(host=host, port=port) + try: await options.waiter free_conn = False - - if new_tunnel: - conn.set_tunnel(new_tunnel) - return conn finally: if free_conn: @@ -450,10 +534,11 @@ async def _connect(options: 'SSHConnectionOptions', await conn.wait_closed() -async def _listen(options: 'SSHConnectionOptions', - loop: asyncio.AbstractEventLoop, flags: int, backlog: int, +async def _listen(options: _Options, config: DefTuple[ConfigPaths], + loop: asyncio.AbstractEventLoop, flags: int, + backlog: int, sock: Optional[socket.socket], reuse_address: bool, reuse_port: bool, - conn_factory: _ConnectionFactory[_Conn], + conn_factory: Callable[[], _Conn], msg: str) -> 'SSHAcceptor': """Make inbound TCP or SSH tunneled listener""" @@ -467,10 +552,16 @@ def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession: tunnel = options.tunnel family = options.family - new_tunnel = await _open_tunnel(tunnel, options.passphrase) + new_tunnel = await _open_tunnel(tunnel, options, config) tunnel: _TunnelListenerProtocol - if new_tunnel: + if sock: + logger.info('%s already-connected socket', msg) + + server: asyncio.AbstractServer = await loop.create_server( + conn_factory, sock=sock, backlog=backlog, + reuse_address=reuse_address, reuse_port=reuse_port) + elif new_tunnel: new_tunnel.logger.info('%s %s via %s', msg, (host, port), tunnel) # pylint: disable=broad-except @@ -501,14 +592,6 @@ def tunnel_factory(_orig_host: str, _orig_port: int) -> SSHTCPSession: return SSHAcceptor(server, options) -async def _run_in_executor(loop: asyncio.AbstractEventLoop, func: Callable, - *args: object, **kwargs: object) -> object: - """Run a potentially blocking call in an executor""" - - return await loop.run_in_executor( - None, functools.partial(func, *args, **kwargs)) - - def _validate_version(version: DefTuple[BytesOrStr]) -> bytes: """Validate requested SSH version""" @@ -552,8 +635,7 @@ def _expand_algs(alg_type: str, algs: str, if pattern.matches(alg.decode('ascii'))] if not matches and strict_match: - raise ValueError('"%s" matches no valid %s algorithms' % - (pat, alg_type)) + raise ValueError(f'"{pat}" matches no valid {alg_type} algorithms') matched.extend(matches) @@ -591,8 +673,8 @@ def _select_algs(alg_type: str, algs: _AlgsArg, config_algs: _AlgsArg, for alg in expanded_algs: if alg not in possible_algs: - raise ValueError('%s is not a valid %s algorithm' % - (alg.decode('ascii'), alg_type)) + raise ValueError(f'{alg.decode("ascii")} is not a valid ' + f'{alg_type} algorithm') if alg not in result: result.append(alg) @@ -601,7 +683,7 @@ def _select_algs(alg_type: str, algs: _AlgsArg, config_algs: _AlgsArg, elif none_value: return [none_value] else: - raise ValueError('No %s algorithms selected' % alg_type) + raise ValueError(f'No {alg_type} algorithms selected') def _select_host_key_algs(algs: _AlgsArg, config_algs: _AlgsArg, @@ -657,7 +739,7 @@ class SSHAcceptor: This class in a wrapper around an :class:`asyncio.Server` listener which provides the ability to update the the set of SSH client or - server connection options associated wtih that listener. This is + server connection options associated with that listener. This is accomplished by calling the :meth:`update` method, which takes the same keyword arguments as the :class:`SSHClientConnectionOptions` and :class:`SSHServerConnectionOptions` classes. @@ -672,7 +754,7 @@ def __init__(self, server: asyncio.AbstractServer, self._server = server self._options = options - async def __aenter__(self) -> 'SSHAcceptor': + async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], @@ -685,13 +767,61 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]], def __getattr__(self, name: str) -> Any: return getattr(self._server, name) + def get_addresses(self) -> List[Tuple]: + """Return socket addresses being listened on + + This method returns the socket addresses being listened on. + It returns tuples of the form returned by + :meth:`socket.getsockname`. If the listener was created + using a hostname, the host's resolved IPs will be returned. + If the requested listening port was `0`, the selected + listening ports will be returned. + + :returns: A list of socket addresses being listened on + + """ + + if hasattr(self._server, 'get_addresses'): + return self._server.get_addresses() + else: + return [sock.getsockname() for sock in self.sockets] + + def get_port(self) -> int: + """Return the port number being listened on + + This method returns the port number being listened on. + If it is listening on multiple sockets with different port + numbers, this function will return `0`. In that case, + :meth:`get_addresses` can be used to retrieve the full + list of listening addresses and ports. + + :returns: The port number being listened on, if there's only one + + """ + + if hasattr(self._server, 'get_port'): + return self._server.get_port() + else: + ports = {addr[1] for addr in self.get_addresses()} + return ports.pop() if len(ports) == 1 else 0 + def close(self) -> None: - """Close this SSH listener""" + """Stop listening for new connections + + This method can be called to stop listening for new + SSH connections. Existing connections will remain open. + + """ self._server.close() async def wait_closed(self) -> None: - """Wait for this SSH listener to close""" + """Wait for this listener to close + + This method is a coroutine which waits for this + listener to be closed. + + """ await self._server.wait_closed() @@ -709,7 +839,7 @@ def update(self, **kwargs: object) -> None: """ - self._options.update(kwargs) + self._options.update(**kwargs) class SSHConnection(SSHPacketHandler, asyncio.Protocol): @@ -810,6 +940,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._kexinit_sent = False self._kex_complete = False self._ignore_first_kex = False + self._strict_kex = False self._gss: Optional[GSSBase] = None self._gss_kex = False @@ -843,6 +974,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._can_send_ext_info = False self._extensions_to_send: 'OrderedDict[bytes, bytes]' = OrderedDict() + self._can_recv_ext_info = False + self._server_sig_algs: Set[bytes] = set() self._next_service: Optional[bytes] = None @@ -852,6 +985,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._auth: Optional[Auth] = None self._auth_in_progress = False self._auth_complete = False + self._auth_final = False self._auth_methods = [b'none'] self._auth_was_trivial = True self._username = '' @@ -868,12 +1002,13 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._x11_listener: Union[None, SSHX11ClientListener, SSHX11ServerListener] = None + self._tasks: Set[asyncio.Task[None]] = set() self._close_event = asyncio.Event() self._server_host_key_algs: Optional[Sequence[bytes]] = None - self._logger = logger.get_child(context='conn=%d' % - self._get_next_conn()) + self._logger = logger.get_child( + context=f'conn={self._get_next_conn()}') self._login_timer: Optional[asyncio.TimerHandle] @@ -885,7 +1020,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._disable_trivial_auth = False - async def __aenter__(self) -> 'SSHConnection': + async def __aenter__(self) -> Self: """Allow SSHConnection to be used as an async context manager""" return self @@ -940,7 +1075,13 @@ def _cleanup(self, exc: Optional[Exception]) -> None: self._wait = None if self._owner: # pragma: no branch - self._owner.connection_lost(exc) + # pylint: disable=broad-except + try: + self._owner.connection_lost(exc) + except Exception: + self.logger.debug1('Uncaught exception in owner ignored', + exc_info=sys.exc_info) + self._owner = None self._cancel_login_timer() @@ -1012,21 +1153,23 @@ def _keepalive_timer_callback(self) -> None: self._set_keepalive_timer() self.create_task(self._make_keepalive_request()) - def _force_close(self, exc: Optional[BaseException]) -> None: + def _force_close(self, exc: Optional[Exception]) -> None: """Force this connection to close immediately""" if not self._transport: return - self._transport.abort() + self._loop.call_soon(self._transport.abort) self._transport = None self._loop.call_soon(self._cleanup, exc) - def _reap_task(self, task_logger: SSHLogger, + def _reap_task(self, task_logger: Optional[SSHLogger], task: 'asyncio.Task[None]') -> None: """Collect result of an async task, reporting errors""" + self._tasks.discard(task) + # pylint: disable=broad-except try: task.result() @@ -1039,11 +1182,14 @@ def _reap_task(self, task_logger: SSHLogger, self.internal_error(error_logger=task_logger) def create_task(self, coro: Awaitable[None], - task_logger: SSHLogger = None) -> 'asyncio.Task[None]': + task_logger: Optional[SSHLogger] = None) -> \ + 'asyncio.Task[None]': """Create an asynchronous task which catches and reports errors""" task = asyncio.ensure_future(coro) task.add_done_callback(partial(self._reap_task, task_logger)) + self._tasks.add(task) + return task def is_client(self) -> bool: @@ -1056,6 +1202,11 @@ def is_server(self) -> bool: return self._server + def is_closed(self) -> bool: + """Return whether the connection is closed""" + + return self._close_event.is_set() + def get_owner(self) -> Optional[Union[SSHClient, SSHServer]]: """Return the SSHClient or SSHServer which owns this connection""" @@ -1075,7 +1226,7 @@ def get_hash_prefix(self) -> bytes: String(self._client_kexinit), String(self._server_kexinit))) - def set_tunnel(self, tunnel: _TunnelProtocol) -> None: + def set_tunnel(self, tunnel: Optional[_TunnelProtocol]) -> None: """Set tunnel used to open this connection""" self._tunnel = tunnel @@ -1205,8 +1356,10 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: if sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, - self._tcp_keepalive) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + 1 if self._tcp_keepalive else 0) + + if sock.family in (socket.AF_INET, socket.AF_INET6): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sockname = cast(SockAddr, transport.get_extra_info('sockname')) @@ -1247,7 +1400,7 @@ def internal_error(self, exc_info: Optional[OptExcInfo] = None, error_logger = self.logger error_logger.debug1('Uncaught exception', exc_info=exc_info) - self._force_close(exc_info[1]) + self._force_close(cast(Exception, exc_info[1])) def session_started(self) -> None: """Handle session start when opening tunneled SSH connection""" @@ -1260,17 +1413,7 @@ def data_received(self, data: bytes, datatype: DataType = None) -> None: self._inpbuf += data - self._reset_keepalive_timer() - - # pylint: disable=broad-except - try: - while self._inpbuf and self._recv_handler(): - pass - except DisconnectError as exc: - self._send_disconnect(exc.code, exc.reason, exc.lang) - self._force_close(exc) - except Exception: - self.internal_error() + self._recv_data() # pylint: enable=arguments-differ def eof_received(self) -> None: @@ -1339,23 +1482,26 @@ def _choose_alg(self, alg_type: str, local_algs: Sequence[bytes], return alg raise KeyExchangeFailed( - 'No matching %s algorithm found, sent %s and received %s' % - (alg_type, b','.join(local_algs).decode('ascii'), - b','.join(remote_algs).decode('ascii'))) + f'No matching {alg_type} algorithm found, sent ' + f'{b",".join(local_algs).decode("ascii")} and received ' + f'{b",".join(remote_algs).decode("ascii")}') - def _get_ext_info_kex_alg(self) -> List[bytes]: - """Return the kex alg to add if any to request extension info""" + def _get_extra_kex_algs(self) -> List[bytes]: + """Return the extra kex algs to add""" - return [b'ext-info-c' if self.is_client() else b'ext-info-s'] + if self.is_client(): + return [b'ext-info-c', b'kex-strict-c-v00@openssh.com'] + else: + return [b'ext-info-s', b'kex-strict-s-v00@openssh.com'] def _send(self, data: bytes) -> None: """Send data to the SSH connection""" if self._transport: - if self._transport.is_closing(): - self._force_close(BrokenPipeError()) - else: + try: self._transport.write(data) + except ConnectionError: # pragma: no cover + pass def _send_version(self) -> None: """Start the SSH handshake""" @@ -1373,6 +1519,21 @@ def _send_version(self) -> None: self._send(version + b'\r\n') + def _recv_data(self) -> None: + """Parse received data""" + + self._reset_keepalive_timer() + + # pylint: disable=broad-except + try: + while self._inpbuf and self._recv_handler(): + pass + except DisconnectError as exc: + self._send_disconnect(exc.code, exc.reason, exc.lang) + self._force_close(exc) + except Exception: + self.internal_error() + def _recv_version(self) -> bool: """Receive and parse the remote SSH version""" @@ -1481,15 +1642,30 @@ def _recv_packet(self) -> bool: skip_reason = '' exc_reason = '' - if self._kex and MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: - if self._ignore_first_kex: # pragma: no cover - skip_reason = 'ignored first kex' - self._ignore_first_kex = False + if MSG_KEX_FIRST <= pkttype <= MSG_KEX_LAST: + if self._kex: + if self._ignore_first_kex: # pragma: no cover + skip_reason = 'ignored first kex' + self._ignore_first_kex = False + else: + handler = self._kex + else: + skip_reason = 'kex not in progress' + exc_reason = 'Key exchange not in progress' + elif self._strict_kex and not self._recv_encryption and \ + MSG_IGNORE <= pkttype <= MSG_DEBUG: + skip_reason = 'strict kex violation' + exc_reason = 'Strict key exchange violation: ' \ + f'unexpected packet type {pkttype} received' + elif MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST: + if self._auth: + handler = self._auth else: - handler = self._kex - elif (self._auth and - MSG_USERAUTH_FIRST <= pkttype <= MSG_USERAUTH_LAST): - handler = self._auth + skip_reason = 'auth not in progress' + exc_reason = 'Authentication not in progress' + elif pkttype > MSG_KEX_LAST and not self._recv_encryption: + skip_reason = 'invalid request before kex complete' + exc_reason = 'Invalid request before key exchange was complete' elif pkttype > MSG_USERAUTH_LAST and not self._auth_complete: skip_reason = 'invalid request before auth complete' exc_reason = 'Invalid request before authentication was complete' @@ -1504,32 +1680,64 @@ def _recv_packet(self) -> bool: handler = self._channels[recv_chan] except KeyError: skip_reason = 'invalid channel number' - exc_reason = 'Invalid channel number %d ' \ - 'received' % recv_chan + exc_reason = f'Invalid channel number {recv_chan} received' handler.log_received_packet(pkttype, seq, packet, skip_reason) if not skip_reason: try: - processed = handler.process_packet(pkttype, seq, packet) + result = handler.process_packet(pkttype, seq, packet) except PacketDecodeError as exc: raise ProtocolError(str(exc)) from None - if not processed: - self.logger.debug1('Unknown packet type %d received', pkttype) - self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq)) + if inspect.isawaitable(result): + # Buffer received data until current packet is processed + self._recv_handler = lambda: False + + task = self.create_task(result) + task.add_done_callback(functools.partial( + self._finish_recv_packet, pkttype, seq, is_async=True)) + + return False + elif not result: + if self._strict_kex and not self._recv_encryption: + exc_reason = 'Strict key exchange violation: ' \ + f'unexpected packet type {pkttype} received' + else: + self.logger.debug1('Unknown packet type %d received', + pkttype) + self.send_packet(MSG_UNIMPLEMENTED, UInt32(seq)) if exc_reason: raise ProtocolError(exc_reason) + self._finish_recv_packet(pkttype, seq) + return True + + def _finish_recv_packet(self, pkttype: int, seq: int, + _task: Optional[asyncio.Task] = None, + is_async: bool = False) -> None: + """Finish processing a packet""" + + if pkttype > MSG_USERAUTH_LAST: + self._auth_final = True + if self._transport: - self._recv_seq = (seq + 1) & 0xffffffff - self._recv_handler = self._recv_pkthdr + if self._recv_seq == 0xffffffff and not self._recv_encryption: + raise ProtocolError('Sequence rollover before kex complete') - return True + if pkttype == MSG_NEWKEYS and self._strict_kex: + self._recv_seq = 0 + else: + self._recv_seq = (seq + 1) & 0xffffffff + + self._recv_handler = self._recv_pkthdr + + if is_async and self._inpbuf: + self._recv_data() def send_packet(self, pkttype: int, *args: bytes, - handler: SSHPacketLogger = None) -> None: + handler: Optional[SSHPacketLogger] = None) -> None: """Send an SSH packet""" if (self._auth_complete and self._kex_complete and @@ -1539,7 +1747,7 @@ def send_packet(self, pkttype: int, *args: bytes, self._send_kexinit() self._kexinit_sent = True - if (((pkttype in {MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or + if (((pkttype in {MSG_DEBUG, MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT} or pkttype > MSG_KEX_LAST) and not self._kex_complete) or (pkttype == MSG_USERAUTH_BANNER and not (self._auth_in_progress or self._auth_complete)) or @@ -1549,7 +1757,7 @@ def send_packet(self, pkttype: int, *args: bytes, # If we're encrypting and we have no data outstanding, insert an # ignore packet into the stream - if self._send_encryption and pkttype not in (MSG_IGNORE, MSG_EXT_INFO): + if self._send_encryption and pkttype > MSG_KEX_LAST: self.send_packet(MSG_IGNORE, String(b'')) orig_payload = Byte(pkttype) + b''.join(args) @@ -1579,7 +1787,15 @@ def send_packet(self, pkttype: int, *args: bytes, mac = b'' self._send(packet + mac) - self._send_seq = (seq + 1) & 0xffffffff + + if self._send_seq == 0xffffffff and not self._send_encryption: + self._send_seq = 0 + raise ProtocolError('Sequence rollover before kex complete') + + if pkttype == MSG_NEWKEYS and self._strict_kex: + self._send_seq = 0 + else: + self._send_seq = (seq + 1) & 0xffffffff if self._kex_complete: self._rekey_bytes_sent += pktlen @@ -1623,7 +1839,7 @@ def _send_kexinit(self) -> None: kex_algs = expand_kex_algs(self._kex_algs, gss_mechs, bool(self._server_host_key_algs)) + \ - self._get_ext_info_kex_alg() + self._get_extra_kex_algs() host_key_algs = self._server_host_key_algs or [b'null'] @@ -1667,7 +1883,7 @@ def _send_ext_info(self) -> None: self.send_packet(MSG_EXT_INFO, packet) - def send_newkeys(self, k: int, h: bytes) -> None: + def send_newkeys(self, k: bytes, h: bytes) -> None: """Finish a key exchange and send a new keys message""" if not self._session_id: @@ -1753,9 +1969,11 @@ def send_newkeys(self, k: int, h: bytes) -> None: not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None - else: - self.send_service_request(_USERAUTH_SERVICE) + return else: + self._extensions_to_send[b'server-sig-algs'] = \ + b','.join(self._sig_algs) + self._send_encryption = next_enc_sc self._send_enchdrlen = 1 if etm_sc else 5 self._send_blocksize = max(8, enc_blocksize_sc) @@ -1776,17 +1994,18 @@ def send_newkeys(self, k: int, h: bytes) -> None: recv_mac=self._mac_alg_cs.decode('ascii'), recv_compression=self._cmp_alg_cs.decode('ascii')) - if first_kex: - self._next_service = _USERAUTH_SERVICE - - self._extensions_to_send[b'server-sig-algs'] = \ - b','.join(self._sig_algs) - if self._can_send_ext_info: self._send_ext_info() self._can_send_ext_info = False self._kex_complete = True + + if first_kex: + if self.is_client(): + self.send_service_request(_USERAUTH_SERVICE) + else: + self._next_service = _USERAUTH_SERVICE + self._send_deferred_packets() def send_service_request(self, service: bytes) -> None: @@ -1811,7 +2030,7 @@ def get_userauth_request_data(self, method: bytes, *args: bytes) -> bytes: self._get_userauth_request_packet(method, args)) def send_userauth_packet(self, pkttype: int, *args: bytes, - handler: SSHPacketLogger = None, + handler: Optional[SSHPacketLogger] = None, trivial: bool = True) -> None: """Send a user authentication packet""" @@ -1856,7 +2075,7 @@ def send_userauth_failure(self, partial_success: bool) -> None: self.send_packet(MSG_USERAUTH_FAILURE, NameList(methods), Boolean(partial_success)) - def send_userauth_success(self) -> None: + async def send_userauth_success(self) -> None: """Send a user authentication success response""" self.logger.info('Auth for user %s succeeded', self._username) @@ -1873,13 +2092,15 @@ def send_userauth_success(self) -> None: self._set_keepalive_timer() if self._owner: # pragma: no branch - self._owner.auth_completed() + result = self._owner.auth_completed() + + if inspect.isawaitable(result): + await result if self._acceptor: result = self._acceptor(self) if inspect.isawaitable(result): - assert result is not None self.create_task(result) self._acceptor = None @@ -1889,6 +2110,11 @@ def send_userauth_success(self) -> None: not self._waiter.cancelled(): self._waiter.set_result(None) self._wait = None + return + + # This method is only in SSHServerConnection + # pylint: disable=no-member + cast(SSHServerConnection, self).send_server_host_keys() def send_channel_open_confirmation(self, send_chan: int, recv_chan: int, recv_window: int, recv_pktsize: int, @@ -1906,6 +2132,13 @@ def send_channel_open_failure(self, send_chan: int, code: int, self.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(send_chan), UInt32(code), String(reason), String(lang)) + def _send_global_request(self, request: bytes, *args: bytes, + want_reply: bool = False) -> None: + """Send a global request""" + + self.send_packet(MSG_GLOBAL_REQUEST, String(request), + Boolean(want_reply), *args) + async def _make_global_request(self, request: bytes, *args: bytes) -> Tuple[int, SSHPacket]: """Send a global request and wait for the response""" @@ -1918,8 +2151,7 @@ async def _make_global_request(self, request: bytes, self._global_request_waiters.append(waiter) - self.send_packet(MSG_GLOBAL_REQUEST, String(request), - Boolean(True), *args) + self._send_global_request(request, *args, want_reply=True) return await waiter @@ -1980,10 +2212,11 @@ def _process_ignore(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process an ignore message""" - # pylint: disable=no-self-use - - _ = packet.get_string() # data - packet.check_end() + # Work around missing payload bytes in an ignore message + # in some Cisco SSH servers + if b'Cisco' not in self._server_version: # pragma: no branch + _ = packet.get_string() # data + packet.check_end() def _process_unimplemented(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -2022,18 +2255,25 @@ def _process_service_request(self, _pkttype: int, _pktid: int, service = packet.get_string() packet.check_end() - if service == self._next_service: - self.logger.debug2('Accepting request for service %s', service) + if self.is_client(): + raise ProtocolError('Unexpected service request received') - self.send_packet(MSG_SERVICE_ACCEPT, String(service)) + if not self._recv_encryption: + raise ProtocolError('Service request received before kex complete') - if (self.is_server() and # pragma: no branch - not self._auth_in_progress and - service == _USERAUTH_SERVICE): - self._auth_in_progress = True - self._send_deferred_packets() - else: - raise ServiceNotAvailable('Unexpected service request received') + if service != self._next_service: + raise ServiceNotAvailable('Unexpected service in service request') + + self.logger.debug2('Accepting request for service %s', service) + + self.send_packet(MSG_SERVICE_ACCEPT, String(service)) + + self._next_service = None + + if service == _USERAUTH_SERVICE: # pragma: no branch + self._auth_in_progress = True + self._can_recv_ext_info = False + self._send_deferred_packets() def _process_service_accept(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -2042,27 +2282,38 @@ def _process_service_accept(self, _pkttype: int, _pktid: int, service = packet.get_string() packet.check_end() - if service == self._next_service: - self.logger.debug2('Request for service %s accepted', service) + if self.is_server(): + raise ProtocolError('Unexpected service accept received') - self._next_service = None + if not self._recv_encryption: + raise ProtocolError('Service accept received before kex complete') - if (self.is_client() and # pragma: no branch - service == _USERAUTH_SERVICE): - self.logger.info('Beginning auth for user %s', self._username) + if service != self._next_service: + raise ServiceNotAvailable('Unexpected service in service accept') - self._auth_in_progress = True + self.logger.debug2('Request for service %s accepted', service) - # This method is only in SSHClientConnection - # pylint: disable=no-member - cast('SSHClientConnection', self).try_next_auth() - else: - raise ServiceNotAvailable('Unexpected service accept received') + self._next_service = None + + if service == _USERAUTH_SERVICE: # pragma: no branch + self.logger.info('Beginning auth for user %s', self._username) + + self._auth_in_progress = True + + if self._owner: # pragma: no branch + self._owner.begin_auth(self._username) + + # This method is only in SSHClientConnection + # pylint: disable=no-member + cast('SSHClientConnection', self).try_next_auth() def _process_ext_info(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process extension information""" + if not self._can_recv_ext_info: + raise ProtocolError('Unexpected ext_info received') + extensions: Dict[bytes, bytes] = {} self.logger.debug2('Received extension info') @@ -2081,8 +2332,8 @@ def _process_ext_info(self, _pkttype: int, _pktid: int, self._server_sig_algs = \ set(extensions.get(b'server-sig-algs', b'').split(b',')) - def _process_kexinit(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_kexinit(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a key exchange request""" if self._kex: @@ -2106,13 +2357,27 @@ def _process_kexinit(self, _pkttype: int, _pktid: int, if self.is_server(): self._client_kexinit = packet.get_consumed_payload() - if b'ext-info-c' in peer_kex_algs and not self._session_id: - self._can_send_ext_info = True + if not self._session_id: + if b'ext-info-c' in peer_kex_algs: + self._can_send_ext_info = True + + if b'kex-strict-c-v00@openssh.com' in peer_kex_algs: + self._strict_kex = True else: self._server_kexinit = packet.get_consumed_payload() - if b'ext-info-s' in peer_kex_algs and not self._session_id: - self._can_send_ext_info = True + if not self._session_id: + if b'ext-info-s' in peer_kex_algs: + self._can_send_ext_info = True + + if b'kex-strict-s-v00@openssh.com' in peer_kex_algs: + self._strict_kex = True + + if self._strict_kex and not self._recv_encryption and \ + self._recv_seq != 0: + raise ProtocolError('Strict key exchange violation: ' + 'KEXINIT was not the first packet') + if self._kexinit_sent: self._kexinit_sent = False @@ -2172,7 +2437,7 @@ def _process_kexinit(self, _pkttype: int, _pktid: int, self.logger.debug1('Beginning key exchange') self.logger.debug2(' Key exchange alg: %s', self._kex.algorithm) - self._kex.start() + await self._kex.start() def _process_newkeys(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -2188,6 +2453,7 @@ def _process_newkeys(self, _pkttype: int, _pktid: int, self._decompress_after_auth = self._next_decompress_after_auth self._next_recv_encryption = None + self._can_recv_ext_info = True else: raise ProtocolError('New keys not negotiated') @@ -2215,8 +2481,10 @@ def _process_userauth_request(self, _pkttype: int, _pktid: int, if self.is_client(): raise ProtocolError('Unexpected userauth request') elif self._auth_complete: - # Silently ignore requests if we're already authenticated - pass + # Silently ignore additional auth requests after auth succeeds, + # until the client sends a non-auth message + if self._auth_final: + raise ProtocolError('Unexpected userauth request') else: if username != self._username: self.logger.info('Beginning auth for user %s', username) @@ -2246,7 +2514,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes, result = await cast(Awaitable[bool], result) if not result: - self.send_userauth_success() + await self.send_userauth_success() return if not self._owner: # pragma: no cover @@ -2258,7 +2526,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes, self._auth = lookup_server_auth(cast(SSHServerConnection, self), self._username, method, packet) - def _process_userauth_failure(self, _pkttype: int, pktid: int, + def _process_userauth_failure(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication failure response""" @@ -2298,10 +2566,9 @@ def _process_userauth_failure(self, _pkttype: int, pktid: int, # pylint: disable=no-member cast(SSHClientConnection, self).try_next_auth() else: - self.logger.debug2('Unexpected userauth failure response') - self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid)) + raise ProtocolError('Unexpected userauth failure response') - def _process_userauth_success(self, _pkttype: int, pktid: int, + def _process_userauth_success(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a user authentication success response""" @@ -2327,6 +2594,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int, self._auth = None self._auth_in_progress = False self._auth_complete = True + self._can_recv_ext_info = False if self._agent: self._agent.close() @@ -2343,7 +2611,6 @@ def _process_userauth_success(self, _pkttype: int, pktid: int, result = self._acceptor(self) if inspect.isawaitable(result): - assert result is not None self.create_task(result) self._acceptor = None @@ -2354,8 +2621,7 @@ def _process_userauth_success(self, _pkttype: int, pktid: int, self._waiter.set_result(None) self._wait = None else: - self.logger.debug2('Unexpected userauth success response') - self.send_packet(MSG_UNIMPLEMENTED, UInt32(pktid)) + raise ProtocolError('Unexpected userauth success response') def _process_userauth_banner(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -2363,6 +2629,13 @@ def _process_userauth_banner(self, _pkttype: int, _pktid: int, msg_bytes = packet.get_string() lang_bytes = packet.get_string() + + # Work around an extra NUL byte appearing in the user + # auth banner message in some versions of cryptlib + if b'cryptlib' in self._server_version and \ + packet.get_remaining_payload() == b'\0': # pragma: no cover + packet.get_byte() + packet.check_end() try: @@ -2536,7 +2809,7 @@ def abort(self) -> None: """Forcibly close the SSH connection This method closes the SSH connection immediately, without - waiting for pending operations to complete and wihtout sending + waiting for pending operations to complete and without sending an explicit SSH disconnect message. Buffered data waiting to be sent will be lost and no more data will be received. When the the connection is closed, :meth:`connection_lost() @@ -2612,6 +2885,8 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: it is established. Supported values include everything supported by a socket transport plus: + | host + | port | username | client_version | server_version @@ -2783,6 +3058,31 @@ def create_unix_channel(self, encoding: Optional[str] = None, return SSHUNIXChannel(self, self._loop, encoding, errors, window, max_pktsize) + def create_tuntap_channel(self, window: int = _DEFAULT_WINDOW, + max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ + SSHTunTapChannel: + """Create a channel to use for TUN/TAP forwarding + + This method can be called by :meth:`tun_requested() + ` or :meth:`tap_requested() + ` to create an :class:`SSHTunTapChannel` + with the desired window and max packet size for a newly created + TUN/TAP tunnel. + + :param window: (optional) + The receive window size for this session + :param max_pktsize: (optional) + The maximum packet size for this session + :type window: `int` + :type max_pktsize: `int` + + :returns: :class:`SSHTunTapChannel` + + """ + + return SSHTunTapChannel(self, self._loop, None, 'strict', + window, max_pktsize) + def create_x11_channel( self, window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHX11Channel: @@ -2877,10 +3177,10 @@ async def forward_unix_connection(self, dest_path: str) -> SSHForwarder: return SSHForwarder(cast(SSHForwarder, peer)) @async_context_manager - async def forward_local_port(self, listen_host: str, - listen_port: int, - dest_host: str, - dest_port: int) -> SSHListener: + async def forward_local_port( + self, listen_host: str, listen_port: int, + dest_host: str, dest_port: int, + accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: """Set up local port forwarding This method is a coroutine which attempts to set up port @@ -2897,10 +3197,17 @@ async def forward_local_port(self, listen_host: str, The hostname or address to forward the connections to :param dest_port: The port number to forward the connections to + :param accept_handler: + A `callable` or coroutine which takes arguments of the + original host and port of the client and decides whether + or not to allow connection forwarding, returning `True` to + accept the connection and begin forwarding or `False` to + reject and close it. :type listen_host: `str` :type listen_port: `int` :type dest_host: `str` :type dest_port: `int` + :type accept_handler: `callable` or coroutine :returns: :class:`SSHListener` @@ -2914,9 +3221,23 @@ async def tunnel_connection( Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: """Forward a local connection over SSH""" - return (await self.create_connection(session_factory, - dest_host, dest_port, - orig_host, orig_port)) + if accept_handler: + result = accept_handler(orig_host, orig_port) + + if inspect.isawaitable(result): + result = await cast(Awaitable[bool], result) + + if not result: + self.logger.info('Request for TCP forwarding from ' + '%s to %s denied by application', + (orig_host, orig_port), + (dest_host, dest_port)) + + raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, + 'Connection forwarding denied') + + return await self.create_connection(session_factory, dest_host, + dest_port, orig_host, orig_port) if (listen_host, listen_port) == (dest_host, dest_port): self.logger.info('Creating local TCP forwarder on %s', @@ -2938,6 +3259,9 @@ async def tunnel_connection( if listen_port == 0: listen_port = listener.get_port() + if dest_port == 0: + dest_port = listen_port + self._local_listeners[listen_host, listen_port] = listener return listener @@ -2989,6 +3313,22 @@ async def tunnel_connection( return listener + def forward_tuntap(self, mode: int, unit: Optional[int]) -> SSHForwarder: + """Set up TUN/TAP forwarding""" + + try: + transport, peer = create_tuntap(SSHForwarder, mode, unit) + interface = transport.get_extra_info('interface') + + self.logger.info(' Forwarding layer %d traffic to %s', + 3 if mode == SSH_TUN_MODE_POINTTOPOINT else 2, + interface) + except OSError as exc: + raise ChannelOpenError(OPEN_CONNECT_FAILED, str(exc)) from None + + return SSHForwarder(cast(SSHForwarder, peer), + extra={'interface': interface}) + def close_forward_listener(self, listen_key: ListenKey) -> None: """Mark a local forwarding listener as closed""" @@ -3026,6 +3366,12 @@ class SSHClientConnection(SSHConnection): UNIX domain socket forwarding can be set up by calling :meth:`forward_local_path` or :meth:`forward_remote_path`. + Mixed forwarding from a TCP port to a UNIX domain socket or + vice-versa can be set up by calling :meth:`forward_local_port_to_path`, + :meth:`forward_local_path_to_port`, + :meth:`forward_remote_port_to_path`, or + :meth:`forward_remote_path_to_port`. + """ _options: 'SSHClientConnectionOptions' @@ -3049,6 +3395,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._server_host_key_algs: Optional[Sequence[bytes]] = None self._server_host_key: Optional[SSHKey] = None + self._server_host_keys_handler = options.server_host_keys_handler + self._username = options.username self._password = options.password @@ -3056,6 +3404,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._client_keys: List[SSHKeyPair] = \ list(options.client_keys) if options.client_keys else [] + self._saved_rsa_key: Optional[_ClientHostKey] = None if options.preferred_auth != (): self._preferred_auth = [method.encode('ascii') for method in @@ -3080,7 +3429,8 @@ def __init__(self, loop: asyncio.AbstractEventLoop, if gss_host: try: - self._gss = GSSClient(gss_host, options.gss_delegate_creds) + self._gss = GSSClient(gss_host, options.gss_store, + options.gss_delegate_creds) self._gss_kex = options.gss_kex self._gss_auth = options.gss_auth self._gss_mic_auth = self._gss_auth @@ -3207,6 +3557,9 @@ def _choose_signature_alg(self, keypair: _ClientHostKey) -> bool: if self._server_sig_algs: for alg in keypair.sig_algorithms: + if keypair.use_webauthn and not alg.startswith(b'webauthn-'): + continue + if alg in self._sig_algs and alg in self._server_sig_algs: keypair.set_sig_algorithm(alg) return True @@ -3221,7 +3574,12 @@ def validate_server_host_key(self, key_data: bytes) -> SSHKey: self._host_key_alias or self._host, self._peer_addr, self._port, key_data) except ValueError as exc: - raise HostKeyNotVerifiable(str(exc)) from None + host = self._host + + if self._host_key_alias: + host += f' with alias {self._host_key_alias}' + + raise HostKeyNotVerifiable(f'{exc} for host {host}') from None self._server_host_key = host_key return host_key @@ -3252,24 +3610,29 @@ def get_server_auth_methods(self) -> Sequence[str]: return [method.decode('ascii') for method in self._auth_methods] - def try_next_auth(self) -> None: + def try_next_auth(self, *, next_method: bool = False) -> None: """Attempt client authentication using the next compatible method""" if self._auth: self._auth.cancel() self._auth = None - while self._auth_methods: - method = self._auth_methods.pop(0) + if next_method: + self._auth_methods.pop(0) - self._auth = lookup_client_auth(self, method) + while self._auth_methods: + self._auth = lookup_client_auth(self, self._auth_methods[0]) if self._auth: return + self._auth_methods.pop(0) + self.logger.info('Auth failed for user %s', self._username) - self._force_close(PermissionDenied('Permission denied')) + self._force_close(PermissionDenied('Permission denied for user ' + f'{self._username} on host ' + f'{self._host}')) def gss_kex_auth_requested(self) -> bool: """Return whether to allow GSS key exchange authentication or not""" @@ -3296,24 +3659,42 @@ async def host_based_auth_requested(self) -> \ if not self._host_based_auth: return None, '', '' + key: Optional[_ClientHostKey] + while True: - try: - key: Optional[_ClientHostKey] = self._client_host_keys.pop(0) - except IndexError: - key = None - break + if self._saved_rsa_key: + key = self._saved_rsa_key + key.algorithm = key.sig_algorithm + b'-cert-v01@openssh.com' + self._saved_rsa_key = None + else: + try: + key = self._client_host_keys.pop(0) + except IndexError: + key = None + break assert key is not None if self._choose_signature_alg(key): + if key.algorithm == b'ssh-rsa-cert-v01@openssh.com' and \ + key.sig_algorithm != b'ssh-rsa': + self._saved_rsa_key = key + break client_host = self._options.client_host if client_host is None: - client_host, _ = await self._loop.getnameinfo( - cast(SockAddr, self.get_extra_info('sockname')), - socket.NI_NUMERICSERV) + sockname = cast(SockAddr, self.get_extra_info('sockname')) + + if sockname: + try: + client_host, _ = await self._loop.getnameinfo( + sockname, socket.NI_NUMERICSERV) + except socket.gaierror: + client_host = sockname[0] + else: + client_host = '' # Add a trailing '.' to the client host to be compatible with # ssh-keysign from OpenSSH @@ -3340,6 +3721,8 @@ async def public_key_auth_requested(self) -> Optional[SSHKeyPair]: self._get_agent_keys = False if self._get_pkcs11_keys: + assert self._pkcs11_provider is not None + pkcs11_keys = await self._loop.run_in_executor( None, load_pkcs11_keys, self._pkcs11_provider, self._pkcs11_pin) @@ -3360,10 +3743,33 @@ async def public_key_auth_requested(self) -> Optional[SSHKeyPair]: self._client_keys = list(load_keypairs(result)) - keypair = self._client_keys.pop(0) + # OpenSSH versions before 7.8 didn't support RSA SHA-2 + # signature names in certificate key types, requiring the + # use of ssh-rsa-cert-v01@openssh.com as the key type even + # when using SHA-2 signatures. However, OpenSSL 8.8 and + # later reject ssh-rsa-cert-v01@openssh.com as a key type + # by default, requiring that the RSA SHA-2 version of the key + # type be used. This makes it difficult to use RSA keys with + # certificates without knowing the version of the remote + # server and which key types it will accept. + # + # The code below works around this by trying multiple key + # types during public key and host-based authentication when + # using SHA-2 signatures with RSA keys signed by certificates. + + if self._saved_rsa_key: + key = self._saved_rsa_key + key.algorithm = key.sig_algorithm + b'-cert-v01@openssh.com' + self._saved_rsa_key = None + else: + key = self._client_keys.pop(0) + + if self._choose_signature_alg(key): + if key.algorithm == b'ssh-rsa-cert-v01@openssh.com' and \ + key.sig_algorithm != b'ssh-rsa': + self._saved_rsa_key = key - if self._choose_signature_alg(keypair): - return keypair + return key async def password_auth_requested(self) -> Optional[str]: """Return a password to authenticate with""" @@ -3373,6 +3779,15 @@ async def password_auth_requested(self) -> Optional[str]: if self._password is not None: password: Optional[str] = self._password + + if callable(password): + password = cast(Callable[[], Optional[str]], password)() + + if inspect.isawaitable(password): + password = await cast(Awaitable[Optional[str]], password) + else: + password = cast(Optional[str], password) + self._password = None else: result = self._owner.password_auth_requested() @@ -3556,6 +3971,20 @@ def _process_direct_streamlocal_at_openssh_dot_com_open( 'Direct UNIX domain socket open ' 'forbidden on client') + def _process_tun_at_openssh_dot_com_open( + self, _packet: SSHPacket) -> \ + Tuple[SSHTunTapChannel, SSHTunTapSession]: + """Process an inbound TUN/TAP open request + + These requests are disallowed on an SSH client. + + """ + + # pylint: disable=no-self-use + + raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, + 'TUN/TAP request forbidden on client') + def _process_forwarded_streamlocal_at_openssh_dot_com_open( self, packet: SSHPacket) -> \ Tuple[SSHUNIXChannel, MaybeAwait[SSHUNIXSession]]: @@ -3638,9 +4067,82 @@ def _process_auth_agent_at_openssh_dot_com_open( raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Auth agent forwarding disabled') - async def attach_x11_listener(self, chan: SSHClientChannel[AnyStr], - display: Optional[str], - auth_path: Optional[str], + def _process_hostkeys_00_at_openssh_dot_com_global_request( + self, packet: SSHPacket) -> None: + """Process a list of accepted server host keys""" + + self.create_task(self._finish_hostkeys(packet)) + + async def _finish_hostkeys(self, packet: SSHPacket) -> None: + """Finish processing hostkeys global request""" + + if not self._server_host_keys_handler: + self.logger.debug1('Ignoring server host key message: no handler') + self._report_global_response(False) + return + + if self._trusted_host_keys is None: + self.logger.info('Server host key not verified: handler disabled') + self._report_global_response(False) + return + + added = [] + removed = list(self._trusted_host_keys) + retained = [] + revoked = [] + prove = [] + + while packet: + try: + key_data = packet.get_string() + key = decode_ssh_public_key(key_data) + + if key in self._revoked_host_keys: + revoked.append(key) + elif key in self._trusted_host_keys: + retained.append(key) + removed.remove(key) + else: + prove.append((key, String(key_data))) + except KeyImportError: + pass + + if prove: + pkttype, packet = await self._make_global_request( + b'hostkeys-prove-00@openssh.com', + b''.join(key_str for _, key_str in prove)) + + if pkttype == MSG_REQUEST_SUCCESS: + prefix = String('hostkeys-prove-00@openssh.com') + \ + String(self._session_id) + + for key, key_str in prove: + sig = packet.get_string() + + if key.verify(prefix + key_str, sig): + added.append(key) + else: + self.logger.debug1('Server host key validation failed') + else: + self.logger.debug1('Server host key prove request failed') + + packet.check_end() + + self.logger.info(f'Server host key report: {len(added)} added, ' + f'{len(removed)} removed, {len(retained)} retained, ' + f'{len(revoked)} revoked') + + result = self._server_host_keys_handler(added, removed, + retained, revoked) + + if inspect.isawaitable(result): + await result + + self._report_global_response(True) + + async def attach_x11_listener(self, chan: SSHClientChannel[AnyStr], + display: Optional[str], + auth_path: Optional[str], single_connection: bool) -> \ Tuple[bytes, bytes, int]: """Attach a channel to a local X11 display""" @@ -3667,8 +4169,8 @@ def detach_x11_listener(self, chan: SSHChannel[AnyStr]) -> None: async def create_session(self, session_factory: SSHClientSessionFactory, command: DefTuple[Optional[str]] = (), *, subsystem: DefTuple[Optional[str]]= (), - env: DefTuple[_Env] = (), - send_env: DefTuple[_SendEnv] = (), + env: DefTuple[Optional[Env]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: DefTuple[Union[bool, str]] = (), term_type: DefTuple[Optional[str]] = (), term_size: DefTuple[TermSizeArg] = (), @@ -3773,8 +4275,8 @@ async def create_session(self, session_factory: SSHClientSessionFactory, :type session_factory: `callable` :type command: `str` :type subsystem: `str` - :type env: `dict` with `str` keys and values - :type send_env: `list` of `str` + :type env: `dict` with `bytes` or `str` keys and values + :type send_env: `list` of `bytes` or `str` :type request_pty: `bool`, `'force'`, or `'auto'` :type term_type: `str` :type term_size: `tuple` of 2 or 4 `int` values @@ -3842,22 +4344,13 @@ async def create_session(self, session_factory: SSHClientSessionFactory, if max_pktsize == (): max_pktsize = self._options.max_pktsize - new_env: Dict[str, str] = {} + new_env: Dict[bytes, bytes] = {} if send_env: - for key in send_env: - pattern = WildcardPattern(key) - new_env.update((key, value) for key, value in os.environ.items() - if pattern.matches(key)) + new_env.update(lookup_env(send_env)) if env: - try: - if isinstance(env, list): - new_env.update((item.split('=', 2) for item in env)) - else: - new_env.update(cast(Mapping[str, str], env)) - except ValueError: - raise ValueError('Invalid environment value') from None + new_env.update(encode_env(env)) if request_pty == 'force': request_pty = True @@ -3918,14 +4411,15 @@ async def open_session(self, *args: object, **kwargs: object) -> \ SSHReader(session, chan, EXTENDED_DATA_STDERR)) # pylint: disable=redefined-builtin - @async_context_manager + @async_context_manager # type: ignore async def create_process(self, *args: object, - bufsize: int = io.DEFAULT_BUFFER_SIZE, input: Optional[AnyStr] = None, stdin: ProcessSource = PIPE, stdout: ProcessTarget = PIPE, stderr: ProcessTarget = PIPE, - **kwargs: object) -> SSHClientProcess: + bufsize: int = io.DEFAULT_BUFFER_SIZE, + send_eof: bool = True, recv_eof: bool = True, + **kwargs: object) -> SSHClientProcess[AnyStr]: """Create a process on the remote system This method is a coroutine wrapper around :meth:`create_session` @@ -3949,8 +4443,6 @@ async def create_process(self, *args: object, :meth:`create_session` except for `session_factory` are supported and have the same meaning. - :param bufsize: (optional) - Buffer size to use when feeding data from a file to stdin :param input: (optional) Input data to feed to standard input of the remote process. If specified, this argument takes precedence over stdin. @@ -3968,8 +4460,23 @@ async def create_process(self, *args: object, :class:`SSHWriter` to feed standard error of the remote process to, `DEVNULL` to discard this output, or `STDOUT` to feed standard error to the same place as stdout. - :type bufsize: `int` + :param bufsize: (optional) + Buffer size to use when feeding data from a file to stdin + :param send_eof: + Whether or not to send EOF to the channel when EOF is + received from stdin, defaulting to `True`. If set to `False`, + the channel will remain open after EOF is received on stdin, + and multiple sources can be redirected to the channel. + :param recv_eof: + Whether or not to send EOF to stdout and stderr when EOF is + received from the channel, defaulting to `True`. If set to + `False`, the redirect targets of stdout and stderr will remain + open after EOF is received on the channel and can be used for + multiple redirects. :type input: `str` or `bytes` + :type bufsize: `int` + :type send_eof: `bool` + :type recv_eof: `bool` :returns: :class:`SSHClientProcess` @@ -3988,7 +4495,8 @@ async def create_process(self, *args: object, chan.write_eof() new_stdin = None - await process.redirect(new_stdin, stdout, stderr, bufsize) + await process.redirect(new_stdin, stdout, stderr, + bufsize, send_eof, recv_eof) return process @@ -4207,9 +4715,10 @@ async def open_connection(self, *args: object, **kwargs: object) -> \ @async_context_manager async def create_server( - self, session_factory: TCPListenerFactory, listen_host: str, - listen_port: int, *, encoding: Optional[str] = None, - errors: str = 'strict', window: int = _DEFAULT_WINDOW, + self, session_factory: TCPListenerFactory[AnyStr], + listen_host: str, listen_port: int, *, + encoding: Optional[str] = None, errors: str = 'strict', + window: int = _DEFAULT_WINDOW, max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> SSHListener: """Create a remote SSH TCP listener @@ -4239,7 +4748,7 @@ async def create_server( The receive window size for this session :param max_pktsize: (optional) The maximum packet size for this session - :type session_factory: `callable` + :type session_factory: `callable` or coroutine :type listen_host: `str` :type listen_port: `int` :type encoding: `str` or `None` @@ -4548,8 +5057,8 @@ async def create_ssh_connection(self, client_factory: _ClientFactory, """ - return (await create_connection(client_factory, host, port, - tunnel=self, **kwargs)) # type: ignore + return await create_connection(client_factory, host, port, + tunnel=self, **kwargs) # type: ignore @async_context_manager async def connect_ssh(self, host: str, port: DefTuple[int] = (), @@ -4614,6 +5123,281 @@ async def listen_reverse_ssh(self, host: str = '', return await listen_reverse(host, port, tunnel=self, **kwargs) # type: ignore + async def create_tun( + self, session_factory: SSHTunTapSessionFactory, + remote_unit: Optional[int] = None, *, window: int = _DEFAULT_WINDOW, + max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ + Tuple[SSHTunTapChannel, SSHTunTapSession]: + """Create an SSH layer 3 tunnel + + This method is a coroutine which can be called to request that + the server open a new outbound layer 3 tunnel to the specified + remote TUN device. If the tunnel is successfully opened, a new + SSH channel will be opened with data being handled by a + :class:`SSHTunTapSession` object created by `session_factory`. + + Optional arguments include the SSH receive window size and max + packet size which default to 2 MB and 32 KB, respectively. + + :param session_factory: + A `callable` which returns an :class:`SSHUNIXSession` object + that will be created to handle activity on this session + :param remote_unit: + The remote TUN device to connect to + :param window: (optional) + The receive window size for this session + :param max_pktsize: (optional) + The maximum packet size for this session + :type session_factory: `callable` + :type remote_unit: `int` or `None` + :type window: `int` + :type max_pktsize: `int` + + :returns: an :class:`SSHTunTapChannel` and :class:`SSHTunTapSession` + + :raises: :exc:`ChannelOpenError` if the connection can't be opened + + """ + + self.logger.info('Opening layer 3 tunnel to remote unit %s', + 'any' if remote_unit is None else str(remote_unit)) + + chan = self.create_tuntap_channel(window, max_pktsize) + + session = await chan.open(session_factory, SSH_TUN_MODE_POINTTOPOINT, + remote_unit) + + return chan, session + + async def create_tap( + self, session_factory: SSHTunTapSessionFactory, + remote_unit: Optional[int] = None, *, window: int = _DEFAULT_WINDOW, + max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> \ + Tuple[SSHTunTapChannel, SSHTunTapSession]: + """Create an SSH layer 2 tunnel + + This method is a coroutine which can be called to request that + the server open a new outbound layer 2 tunnel to the specified + remote TAP device. If the tunnel is successfully opened, a new + SSH channel will be opened with data being handled by a + :class:`SSHTunTapSession` object created by `session_factory`. + + Optional arguments include the SSH receive window size and max + packet size which default to 2 MB and 32 KB, respectively. + + :param session_factory: + A `callable` which returns an :class:`SSHUNIXSession` object + that will be created to handle activity on this session + :param remote_unit: + The remote TAP device to connect to + :param window: (optional) + The receive window size for this session + :param max_pktsize: (optional) + The maximum packet size for this session + :type session_factory: `callable` + :type remote_unit: `int` or `None` + :type window: `int` + :type max_pktsize: `int` + + :returns: an :class:`SSHTunTapChannel` and :class:`SSHTunTapSession` + + :raises: :exc:`ChannelOpenError` if the connection can't be opened + + """ + + self.logger.info('Opening layer 2 tunnel to remote unit %s', + 'any' if remote_unit is None else str(remote_unit)) + + chan = self.create_tuntap_channel(window, max_pktsize) + + session = await chan.open(session_factory, SSH_TUN_MODE_ETHERNET, + remote_unit) + + return chan, session + + async def open_tun(self, *args: object, **kwargs: object) -> \ + Tuple[SSHReader, SSHWriter]: + """Open an SSH layer 3 tunnel + + This method is a coroutine wrapper around :meth:`create_tun` + designed to provide a "high-level" stream interface for creating + an SSH layer 3 tunnel. Instead of taking a `session_factory` + argument for constructing an object which will handle activity + on the session via callbacks, it returns :class:`SSHReader` and + :class:`SSHWriter` objects which can be used to perform I/O on + the tunnel. + + With the exception of `session_factory`, all of the arguments + to :meth:`create_tun` are supported and have the same meaning here. + + :returns: an :class:`SSHReader` and :class:`SSHWriter` + + :raises: :exc:`ChannelOpenError` if the connection can't be opened + + """ + + chan, session = await self.create_tun(SSHTunTapStreamSession, + *args, **kwargs) # type: ignore + + session: SSHTunTapStreamSession + + return SSHReader(session, chan), SSHWriter(session, chan) + + async def open_tap(self, *args: object, **kwargs: object) -> \ + Tuple[SSHReader, SSHWriter]: + """Open an SSH layer 2 tunnel + + This method is a coroutine wrapper around :meth:`create_tap` + designed to provide a "high-level" stream interface for creating + an SSH layer 2 tunnel. Instead of taking a `session_factory` + argument for constructing an object which will handle activity + on the session via callbacks, it returns :class:`SSHReader` and + :class:`SSHWriter` objects which can be used to perform I/O on + the tunnel. + + With the exception of `session_factory`, all of the arguments + to :meth:`create_tap` are supported and have the same meaning here. + + :returns: an :class:`SSHReader` and :class:`SSHWriter` + + :raises: :exc:`ChannelOpenError` if the connection can't be opened + + """ + + chan, session = await self.create_tap(SSHTunTapStreamSession, + *args, **kwargs) # type: ignore + + session: SSHTunTapStreamSession + + return SSHReader(session, chan), SSHWriter(session, chan) + + @async_context_manager + async def forward_local_port_to_path( + self, listen_host: str, listen_port: int, dest_path: str, + accept_handler: Optional[SSHAcceptHandler] = None) -> SSHListener: + """Set up local TCP port forwarding to a remote UNIX domain socket + + This method is a coroutine which attempts to set up port + forwarding from a local TCP listening port to a remote UNIX + domain path via the SSH connection. If the request is successful, + the return value is an :class:`SSHListener` object which can be + used later to shut down the port forwarding. + + :param listen_host: + The hostname or address on the local host to listen on + :param listen_port: + The port number on the local host to listen on + :param dest_path: + The path on the remote host to forward the connections to + :param accept_handler: + A `callable` or coroutine which takes arguments of the + original host and port of the client and decides whether + or not to allow connection forwarding, returning `True` to + accept the connection and begin forwarding or `False` to + reject and close it. + :type listen_host: `str` + :type listen_port: `int` + :type dest_path: `str` + :type accept_handler: `callable` or coroutine + + :returns: :class:`SSHListener` + + :raises: :exc:`OSError` if the listener can't be opened + + """ + + async def tunnel_connection( + session_factory: SSHUNIXSessionFactory[bytes], + orig_host: str, orig_port: int) -> \ + Tuple[SSHUNIXChannel[bytes], SSHUNIXSession[bytes]]: + """Forward a local connection over SSH""" + + if accept_handler: + result = accept_handler(orig_host, orig_port) + + if inspect.isawaitable(result): + result = await cast(Awaitable[bool], result) + + if not result: + self.logger.info('Request for TCP forwarding from ' + '%s to %s denied by application', + (orig_host, orig_port), dest_path) + + raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, + 'Connection forwarding denied') + + return await self.create_unix_connection(session_factory, dest_path) + + self.logger.info('Creating local TCP forwarder from %s to %s', + (listen_host, listen_port), dest_path) + + try: + listener = await create_tcp_forward_listener(self, self._loop, + tunnel_connection, + listen_host, + listen_port) + except OSError as exc: + self.logger.debug1('Failed to create local TCP listener: %s', exc) + raise + + if listen_port == 0: + listen_port = listener.get_port() + + self._local_listeners[listen_host, listen_port] = listener + + return listener + + @async_context_manager + async def forward_local_path_to_port(self, listen_path: str, + dest_host: str, + dest_port: int) -> SSHListener: + """Set up local UNIX domain socket forwarding to a remote TCP port + + This method is a coroutine which attempts to set up UNIX domain + socket forwarding from a local listening path to a remote host + and port via the SSH connection. If the request is successful, + the return value is an :class:`SSHListener` object which can + be used later to shut down the UNIX domain socket forwarding. + + :param listen_path: + The path on the local host to listen on + :param dest_host: + The hostname or address to forward the connections to + :param dest_port: + The port number to forward the connections to + :type listen_path: `str` + :type dest_host: `str` + :type dest_port: `int` + + :returns: :class:`SSHListener` + + :raises: :exc:`OSError` if the listener can't be opened + + """ + + async def tunnel_connection( + session_factory: SSHTCPSessionFactory[bytes]) -> \ + Tuple[SSHTCPChannel[bytes], SSHTCPSession[bytes]]: + """Forward a local connection over SSH""" + + return await self.create_connection(session_factory, dest_host, + dest_port, '', 0) + + self.logger.info('Creating local UNIX forwarder from %s to %s', + listen_path, (dest_host, dest_port)) + + try: + listener = await create_unix_forward_listener(self, self._loop, + tunnel_connection, + listen_path) + except OSError as exc: + self.logger.debug1('Failed to create local UNIX listener: %s', exc) + raise + + self._local_listeners[listen_path] = listener + + return listener + @async_context_manager async def forward_remote_port(self, listen_host: str, listen_port: int, dest_host: str, @@ -4647,36 +5431,118 @@ async def forward_remote_port(self, listen_host: str, """ def session_factory(_orig_host: str, - _orig_port: int) -> Awaitable[SSHTCPSession]: + _orig_port: int) -> Awaitable[SSHTCPSession]: + """Return an SSHTCPSession used to do remote port forwarding""" + + return cast(Awaitable[SSHTCPSession], + self.forward_connection(dest_host, dest_port)) + + self.logger.info('Creating remote TCP forwarder from %s to %s', + (listen_host, listen_port), (dest_host, dest_port)) + + return await self.create_server(session_factory, listen_host, + listen_port) + + @async_context_manager + async def forward_remote_path(self, listen_path: str, + dest_path: str) -> SSHListener: + """Set up remote UNIX domain socket forwarding + + This method is a coroutine which attempts to set up UNIX domain + socket forwarding from a remote listening path to a local path + via the SSH connection. If the request is successful, the + return value is an :class:`SSHListener` object which can be + used later to shut down the port forwarding. If the request + fails, `None` is returned. + + :param listen_path: + The path on the remote host to listen on + :param dest_path: + The path on the local host to forward connections to + :type listen_path: `str` + :type dest_path: `str` + + :returns: :class:`SSHListener` + + :raises: :class:`ChannelListenError` if the listener can't be opened + + """ + + def session_factory() -> Awaitable[SSHUNIXSession[bytes]]: + """Return an SSHUNIXSession used to do remote path forwarding""" + + return cast(Awaitable[SSHUNIXSession[bytes]], + self.forward_unix_connection(dest_path)) + + self.logger.info('Creating remote UNIX forwarder from %s to %s', + listen_path, dest_path) + + return await self.create_unix_server(session_factory, listen_path) + + @async_context_manager + async def forward_remote_port_to_path(self, listen_host: str, + listen_port: int, + dest_path: str) -> SSHListener: + """Set up remote TCP port forwarding to a local UNIX domain socket + + This method is a coroutine which attempts to set up port + forwarding from a remote TCP listening port to a local UNIX + domain socket path via the SSH connection. If the request is + successful, the return value is an :class:`SSHListener` object + which can be used later to shut down the port forwarding. If + the request fails, `None` is returned. + + :param listen_host: + The hostname or address on the remote host to listen on + :param listen_port: + The port number on the remote host to listen on + :param dest_path: + The path on the local host to forward connections to + :type listen_host: `str` + :type listen_port: `int` + :type dest_path: `str` + + :returns: :class:`SSHListener` + + :raises: :class:`ChannelListenError` if the listener can't be opened + + """ + + def session_factory(_orig_host: str, + _orig_port: int) -> Awaitable[SSHUNIXSession]: """Return an SSHTCPSession used to do remote port forwarding""" - return cast(Awaitable[SSHTCPSession], - self.forward_connection(dest_host, dest_port)) + return cast(Awaitable[SSHUNIXSession], + self.forward_unix_connection(dest_path)) self.logger.info('Creating remote TCP forwarder from %s to %s', - (listen_host, listen_port), (dest_host, dest_port)) + (listen_host, listen_port), dest_path) return await self.create_server(session_factory, listen_host, listen_port) @async_context_manager - async def forward_remote_path(self, listen_path: str, - dest_path: str) -> SSHListener: - """Set up remote UNIX domain socket forwarding + async def forward_remote_path_to_port(self, listen_path: str, + dest_host: str, + dest_port: int) -> SSHListener: + """Set up remote UNIX domain socket forwarding to a local TCP port This method is a coroutine which attempts to set up UNIX domain - socket forwarding from a remote listening path to a local path - via the SSH connection. If the request is successful, the - return value is an :class:`SSHListener` object which can be - used later to shut down the port forwarding. If the request - fails, `None` is returned. + socket forwarding from a remote listening path to a local TCP + host and port via the SSH connection. If the request is + successful, the return value is an :class:`SSHListener` object + which can be used later to shut down the port forwarding. If + the request fails, `None` is returned. :param listen_path: The path on the remote host to listen on - :param dest_path: - The path on the local host to forward connections to + :param dest_host: + The hostname or address to forward connections to + :param dest_port: + The port number to forward connections to :type listen_path: `str` - :type dest_path: `str` + :type dest_host: `str` + :type dest_port: `int` :returns: :class:`SSHListener` @@ -4684,14 +5550,14 @@ async def forward_remote_path(self, listen_path: str, """ - def session_factory() -> Awaitable[SSHUNIXSession[bytes]]: + def session_factory() -> Awaitable[SSHTCPSession[bytes]]: """Return an SSHUNIXSession used to do remote path forwarding""" - return cast(Awaitable[SSHUNIXSession[bytes]], - self.forward_unix_connection(dest_path)) + return cast(Awaitable[SSHTCPSession[bytes]], + self.forward_connection(dest_host, dest_port)) self.logger.info('Creating remote UNIX forwarder from %s to %s', - listen_path, dest_path) + listen_path, (dest_host, dest_port)) return await self.create_unix_server(session_factory, listen_path) @@ -4752,8 +5618,77 @@ async def tunnel_socks(session_factory: SSHTCPSessionFactory[bytes], return listener @async_context_manager - async def start_sftp_client(self, env: DefTuple[_Env] = (), - send_env: DefTuple[_SendEnv] = (), + async def forward_tun(self, local_unit: Optional[int] = None, + remote_unit: Optional[int] = None) -> SSHForwarder: + """Set up layer 3 forwarding + + This method is a coroutine which attempts to set up layer 3 + packet forwarding between local and remote TUN devices. If the + request is successful, the return value is an :class:`SSHForwarder` + object which can be used later to shut down the forwarding. + + :param local_unit: + The unit number of the local TUN device to use + :param remote_unit: + The unit number of the remote TUN device to use + :type local_unit: `int` or `None` + :type remote_unit: `int` or `None` + + :returns: :class:`SSHForwarder` + + :raises: | :exc:`OSError` if the local TUN device can't be opened + | :exc:`ChannelOpenError` if the SSH channel can't be opened + + """ + + def session_factory() -> SSHTunTapSession: + """Return an SSHTunTapSession used to do layer 3 forwarding""" + + return cast(SSHTunTapSession, + self.forward_tuntap(SSH_TUN_MODE_POINTTOPOINT, + local_unit)) + + _, peer = await self.create_tun(session_factory, remote_unit) + + return cast(SSHForwarder, peer) + + @async_context_manager + async def forward_tap(self, local_unit: Optional[int] = None, + remote_unit: Optional[int] = None) -> SSHForwarder: + """Set up layer 2 forwarding + + This method is a coroutine which attempts to set up layer 2 + packet forwarding between local and remote TAP devices. If the + request is successful, the return value is an :class:`SSHForwarder` + object which can be used later to shut down the forwarding. + + :param local_unit: + The unit number of the local TAP device to use + :param remote_unit: + The unit number of the remote TAP device to use + :type local_unit: `int` or `None` + :type remote_unit: `int` or `None` + + :returns: :class:`SSHForwarder` + + :raises: | :exc:`OSError` if the local TUN device can't be opened + | :exc:`ChannelOpenError` if the SSH channel can't be opened + + """ + + def session_factory() -> SSHTunTapSession: + """Return an SSHTunTapSession used to do layer 2 forwarding""" + + return cast(SSHTunTapSession, + self.forward_tuntap(SSH_TUN_MODE_ETHERNET, local_unit)) + + _, peer = await self.create_tap(session_factory, remote_unit) + + return cast(SSHForwarder, peer) + + @async_context_manager + async def start_sftp_client(self, env: DefTuple[Optional[Env]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), path_encoding: Optional[str] = 'utf-8', path_errors = 'strict', sftp_version = MIN_SFTP_VERSION) -> SFTPClient: @@ -4865,6 +5800,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._options = options self._server_host_keys = options.server_host_keys + self._all_server_host_keys = options.all_server_host_keys self._server_host_key_algs = list(options.server_host_keys.keys()) self._known_client_hosts = options.known_client_hosts self._trust_client_host = options.trust_client_host @@ -4890,7 +5826,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop, if options.gss_host: try: - self._gss = GSSServer(options.gss_host) + self._gss = GSSServer(options.gss_host, options.gss_store) self._gss_kex = options.gss_kex self._gss_auth = options.gss_auth self._gss_mic_auth = self._gss_auth @@ -4935,11 +5871,10 @@ async def reload_config(self) -> None: self._peer_host, _ = await self._loop.getnameinfo( (self._peer_addr, self._peer_port), socket.NI_NUMERICSERV) - options = cast(SSHServerConnectionOptions, await _run_in_executor( - self._loop, SSHServerConnectionOptions, options=self._options, - reload=True, accept_addr=self._local_addr, + options = await SSHServerConnectionOptions.construct( + options=self._options, reload=True, accept_addr=self._local_addr, accept_port=self._local_port, username=self._username, - client_host=self._peer_host, client_addr=self._peer_addr)) + client_host=self._peer_host, client_addr=self._peer_addr) self._options = options @@ -4960,7 +5895,7 @@ async def reload_config(self) -> None: self._keepalive_interval = options.keepalive_interval def choose_server_host_key(self, - peer_host_key_algs: Sequence[bytes]) -> bool: + peer_host_key_algs: Sequence[bytes]) -> bool: """Choose the server host key to use Given a list of host key algorithms supported by the client, @@ -4991,6 +5926,17 @@ def get_server_host_key(self) -> Optional[SSHKeyPair]: return self._server_host_key + def send_server_host_keys(self) -> None: + """Send list of available server host keys""" + + if self._all_server_host_keys: + self.logger.info('Sending server host keys') + + keys = [String(key) for key in self._all_server_host_keys.keys()] + self._send_global_request(b'hostkeys-00@openssh.com', *keys) + else: + self.logger.info('Sending server host keys disabled') + def gss_kex_auth_supported(self) -> bool: """Return whether GSS key exchange authentication is supported""" @@ -5041,9 +5987,13 @@ async def validate_host_based_auth(self, username: str, key_data: bytes, if self._trust_client_host: resolved_host = client_host else: - resolved_host, _ = await self._loop.getnameinfo( - cast(SockAddr, self.get_extra_info('peername')), - socket.NI_NUMERICSERV) + peername = cast(SockAddr, self.get_extra_info('peername')) + + try: + resolved_host, _ = await self._loop.getnameinfo( + peername, socket.NI_NUMERICSERV) + except socket.gaierror: + resolved_host = peername[0] if resolved_host != client_host: self.logger.info('Client host mismatch: received %s, ' @@ -5097,11 +6047,10 @@ async def _validate_openssh_certificate( self._key_options = options - if self.get_key_option('principals'): - username = '' + cert_user = None if self.get_key_option('principals') else username try: - cert.validate(CERT_TYPE_USER, username) + cert.validate(CERT_TYPE_USER, cert_user) except ValueError: return None @@ -5356,6 +6305,9 @@ def _process_session_open(self, packet: SSHPacket) -> \ if not result: raise ChannelOpenError(OPEN_CONNECT_FAILED, 'Session refused') + if isinstance(result, SSHClientConnection): + result = self.forward_tunneled_session(result) + if isinstance(result, tuple): chan, result = result else: @@ -5399,8 +6351,8 @@ def _process_direct_tcpip_open(self, packet: SSHPacket) -> \ (dest_host, dest_port) not in permitted_opens and \ (dest_host, None) not in permitted_opens: raise ChannelOpenError(OPEN_ADMINISTRATIVELY_PROHIBITED, - 'Port forwarding not permitted to %s ' - 'port %s' % (dest_host, dest_port)) + 'Port forwarding not permitted to ' + f'{dest_host} port {dest_port}') result = self._owner.connection_requested(dest_host, dest_port, orig_host, orig_port) @@ -5411,6 +6363,10 @@ def _process_direct_tcpip_open(self, packet: SSHPacket) -> \ if result is True: result = cast(SSHTCPSession[bytes], self.forward_connection(dest_host, dest_port)) + elif isinstance(result, SSHClientConnection): + result = cast(Awaitable[SSHTCPSession[bytes]], + self.forward_tunneled_connection( + result, dest_host, dest_port)) if isinstance(result, tuple): chan, result = result @@ -5468,6 +6424,10 @@ async def _finish_port_forward(self, listen_host: str, if listener is True: listener = await self.forward_local_port( listen_host, listen_port, listen_host, listen_port) + elif callable(listener): + listener = await self.forward_local_port( + listen_host, listen_port, + listen_host, listen_port, listener) except OSError: self.logger.debug1('Failed to create TCP listener') self._report_global_response(False) @@ -5553,6 +6513,10 @@ def _process_direct_streamlocal_at_openssh_dot_com_open( if result is True: result = cast(SSHUNIXSession[bytes], self.forward_unix_connection(dest_path)) + elif isinstance(result, SSHClientConnection): + result = cast(Awaitable[SSHUNIXSession[bytes]], + self.forward_tunneled_unix_connection( + result, dest_path)) if isinstance(result, tuple): chan, result = result @@ -5648,6 +6612,75 @@ def _process_cancel_streamlocal_forward_at_openssh_dot_com_global_request( self._report_global_response(True) + def _process_tun_at_openssh_dot_com_open( + self, packet: SSHPacket) -> \ + Tuple[SSHTunTapChannel, SSHTunTapSession]: + """Process an incoming TUN/TAP open request""" + + mode = packet.get_uint32() + unit: Optional[int] = packet.get_uint32() + packet.check_end() + + if unit == SSH_TUN_UNIT_ANY: + unit = None + + if mode == SSH_TUN_MODE_POINTTOPOINT: + result = self._owner.tun_requested(unit) + elif mode == SSH_TUN_MODE_ETHERNET: + result = self._owner.tap_requested(unit) + else: + result = False + + if not result: + raise ChannelOpenError(OPEN_CONNECT_FAILED, + 'TUN/TAP request refused') + + if result is True: + result = cast(SSHTunTapSession, self.forward_tuntap(mode, unit)) + elif isinstance(result, SSHClientConnection): + result = cast(Awaitable[SSHTunTapSession], + self.forward_tunneled_tuntap(result, mode, unit)) + + if isinstance(result, tuple): + chan, result = result + else: + chan = self.create_tuntap_channel() + + session: SSHTunTapSession + + if callable(result): + session = SSHTunTapStreamSession(result) + else: + session = cast(SSHTunTapSession, result) + + self.logger.info('Accepted layer %d tunnel request to unit %s', + 3 if mode == SSH_TUN_MODE_POINTTOPOINT else 2, + 'any' if unit == SSH_TUN_UNIT_ANY else str(unit)) + + chan.set_mode(mode) + + return chan, session + + def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( + self, packet: SSHPacket) -> None: + """Prove the server has private keys for all requested host keys""" + + prefix = String('hostkeys-prove-00@openssh.com') + \ + String(self._session_id) + + signatures = [] + + while packet: + try: + key_data = packet.get_string() + key = self._all_server_host_keys[key_data] + signatures.append(String(key.sign(prefix + String(key_data)))) + except (KeyError, KeyImportError): + self._report_global_response(False) + return + + self._report_global_response(b''.join(signatures)) + async def attach_x11_listener(self, chan: SSHServerChannel[AnyStr], auth_proto: bytes, auth_data: bytes, screen: int) -> Optional[str]: @@ -6161,8 +7194,78 @@ async def open_agent_connection(self) -> \ return SSHReader[bytes](session, chan), SSHWriter[bytes](session, chan) + async def forward_tunneled_session( + self, conn: SSHClientConnection) -> SSHServerProcess: + """Forward a tunneled session between SSH connections""" + + async def process_factory(process: SSHServerProcess) -> None: + """Return an upstream process used to forward the session""" -class SSHConnectionOptions(Options): + encoding, errors = process.channel.get_encoding() + + upstream_process: SSHClientProcess = await conn.create_process( + command=process.command, subsystem=process.subsystem, + env=process.env, term_type=process.term_type, + term_size=process.term_size, term_modes=process.term_modes, + encoding=encoding, errors=errors, stdin=process.stdin, + stdout=process.stdout, stderr=process.stderr) + + await upstream_process.wait_closed() + + self.logger.info(' Forwarding session via SSH tunnel') + + return SSHServerProcess(process_factory, None, MIN_SFTP_VERSION, False) + + async def forward_tunneled_connection( + self, conn: SSHClientConnection, + dest_host: str, dest_port: int) -> SSHForwarder: + """Forward a tunneled TCP connection between SSH connections""" + + _, peer = await conn.create_connection( + cast(SSHTCPSessionFactory[bytes], SSHForwarder), + dest_host, dest_port) + + self.logger.info(' Forwarding TCP connection to %s via SSH tunnel', + (dest_host, dest_port)) + + return SSHForwarder(cast(SSHForwarder, peer)) + + async def forward_tunneled_unix_connection( + self, conn: SSHClientConnection, + dest_path: str) -> SSHForwarder: + """Forward a tunneled UNIX connection between SSH connections""" + + _, peer = await conn.create_unix_connection( + cast(SSHUNIXSessionFactory[bytes], SSHForwarder), dest_path) + + self.logger.info(' Forwarding UNIX connection to %s via SSH tunnel', + dest_path) + + return SSHForwarder(cast(SSHForwarder, peer)) + + async def forward_tunneled_tuntap( + self, conn: SSHClientConnection, + mode: int, unit: Optional[int]) -> SSHForwarder: + """Forward a TUN/TAP connection between SSH connections""" + + if mode == SSH_TUN_MODE_POINTTOPOINT: + create_func = conn.create_tun + layer = 3 + else: + create_func = conn.create_tap + layer = 2 + + transport, peer = await create_func( + cast(SSHTunTapSessionFactory, SSHForwarder), unit) + interface = transport.get_extra_info('interface') + + self.logger.info(' Forwarding layer %d traffic to %s via SSH tunnel', + layer, interface) + + return SSHForwarder(cast(SSHForwarder, peer)) + + +class SSHConnectionOptions(Options, Generic[_Options]): """SSH connection options""" config: SSHConfig @@ -6176,6 +7279,11 @@ class SSHConnectionOptions(Options): family: int local_addr: HostPort tcp_keepalive: bool + canonicalize_hostname: Union[bool, str] + canonical_domains: Sequence[str] + canonicalize_fallback_local: bool + canonicalize_max_dots: int + canonicalize_permitted_cnames: Sequence[Tuple[str, str]] kex_algs: Sequence[bytes] encryption_algs: Sequence[bytes] mac_algs: Sequence[bytes] @@ -6195,11 +7303,20 @@ class SSHConnectionOptions(Options): keepalive_internal: float keepalive_count_max: int - def __init__(self, options: Optional['SSHConnectionOptions'] = None, - **kwargs: object): + def __init__(self, options: Optional[_Options] = None, **kwargs: object): last_config = options.config if options else None super().__init__(options=options, last_config=last_config, **kwargs) + @classmethod + async def construct(cls, options: Optional[_Options] = None, + **kwargs: object) -> _Options: + """Construct a new options object from within an async task""" + + loop = asyncio.get_event_loop() + + return cast(_Options, await loop.run_in_executor( + None, functools.partial(cls, options, loop=loop, **kwargs))) + # pylint: disable=arguments-differ def prepare(self, config: SSHConfig, # type: ignore protocol_factory: _ProtocolFactory, version: _VersionArg, @@ -6207,6 +7324,11 @@ def prepare(self, config: SSHConfig, # type: ignore passphrase: Optional[BytesOrStr], proxy_command: DefTuple[_ProxyCommand], family: DefTuple[int], local_addr: DefTuple[HostPort], tcp_keepalive: DefTuple[bool], + canonicalize_hostname: DefTuple[Union[bool, str]], + canonical_domains: DefTuple[Sequence[str]], + canonicalize_fallback_local: DefTuple[bool], + canonicalize_max_dots: DefTuple[int], + canonicalize_permitted_cnames: _CNAMEArg, kex_algs: _AlgsArg, encryption_algs: _AlgsArg, mac_algs: _AlgsArg, compression_algs: _AlgsArg, signature_algs: _AlgsArg, host_based_auth: _AuthArg, @@ -6222,6 +7344,20 @@ def prepare(self, config: SSHConfig, # type: ignore keepalive_count_max: int) -> None: """Prepare common connection configuration options""" + def _split_cname_patterns( + patterns: Union[str, Tuple[str, str]]) -> Tuple[str, str]: + """Split CNAME patterns""" + + if isinstance(patterns, str): + domains = patterns.split(':') + + if len(domains) == 2: + patterns = cast(Tuple[str, str], tuple(domains)) + else: + raise ValueError('CNAME rules must contain two patterns') + + return patterns + self.config = config self.protocol_factory = protocol_factory self.version = _validate_version(version) @@ -6233,11 +7369,13 @@ def prepare(self, config: SSHConfig, # type: ignore self.tunnel = tunnel if tunnel != () else config.get('ProxyJump') self.passphrase = passphrase + if proxy_command == (): + proxy_command = cast(Optional[str], config.get('ProxyCommand')) + if isinstance(proxy_command, str): - proxy_command = shlex.split(proxy_command) + proxy_command = split_args(proxy_command) - self.proxy_command = proxy_command if proxy_command != () else \ - cast(Sequence[str], config.get('ProxyCommand')) + self.proxy_command = proxy_command self.family = cast(int, family if family != () else config.get('AddressFamily', socket.AF_UNSPEC)) @@ -6250,6 +7388,32 @@ def prepare(self, config: SSHConfig, # type: ignore self.tcp_keepalive = cast(bool, tcp_keepalive if tcp_keepalive != () else config.get('TCPKeepAlive', True)) + self.canonicalize_hostname = \ + cast(Union[bool, str], canonicalize_hostname + if canonicalize_hostname != () + else config.get('CanonicalizeHostname', False)) + + self.canonical_domains = \ + cast(Sequence[str], canonical_domains if canonical_domains != () + else config.get('CanonicalDomains', ())) + + self.canonicalize_fallback_local = \ + cast(bool, canonicalize_fallback_local \ + if canonicalize_fallback_local != () + else config.get('CanonicalizeFallbackLocal', True)) + + self.canonicalize_max_dots = \ + cast(int, canonicalize_max_dots if canonicalize_max_dots != () + else config.get('CanonicalizeMaxDots', 1)) + + permitted_cnames = \ + cast(Sequence[str], canonicalize_permitted_cnames + if canonicalize_permitted_cnames != () + else config.get('CanonicalizePermittedCNAMEs', ())) + + self.canonicalize_permitted_cnames = \ + [_split_cname_patterns(patterns) for patterns in permitted_cnames] + self.kex_algs, self.encryption_algs, self.mac_algs, \ self.compression_algs, self.signature_algs = \ _validate_algs(config, kex_algs, encryption_algs, mac_algs, @@ -6286,7 +7450,8 @@ def prepare(self, config: SSHConfig, # type: ignore if x509_trusted_cert_paths: for path in x509_trusted_cert_paths: if not Path(path).is_dir(): - raise ValueError('Path not a directory: ' + str(path)) + raise ValueError('X.509 trusted certificate path not ' + f'a directory: {path}') self.x509_trusted_certs = x509_trusted_certs self.x509_trusted_cert_paths = x509_trusted_cert_paths @@ -6304,9 +7469,7 @@ def prepare(self, config: SSHConfig, # type: ignore elif isinstance(rekey_bytes, str): rekey_bytes = parse_byte_count(rekey_bytes) - rekey_bytes: int - - if rekey_bytes <= 0: + if cast(int, rekey_bytes) <= 0: raise ValueError('Rekey bytes cannot be negative or zero') if rekey_seconds == (): @@ -6317,9 +7480,7 @@ def prepare(self, config: SSHConfig, # type: ignore elif isinstance(rekey_seconds, str): rekey_seconds = parse_time_interval(rekey_seconds) - rekey_seconds: float - - if rekey_seconds and rekey_seconds <= 0: + if rekey_seconds and cast(float, rekey_seconds) <= 0: raise ValueError('Rekey seconds cannot be negative or zero') if isinstance(connect_timeout, str): @@ -6343,8 +7504,8 @@ def prepare(self, config: SSHConfig, # type: ignore if keepalive_count_max <= 0: raise ValueError('Keepalive count max cannot be negative or zero') - self.rekey_bytes = rekey_bytes - self.rekey_seconds = rekey_seconds + self.rekey_bytes = cast(int, rekey_bytes) + self.rekey_seconds = cast(float, rekey_seconds) self.connect_timeout = connect_timeout or None self.login_timeout = login_timeout self.keepalive_interval = keepalive_interval @@ -6391,6 +7552,17 @@ class SSHClientConnectionOptions(SSHConnectionOptions): caution, as it can result in a host key mismatch if the client trusts only a subset of the host keys the server might return. + :param server_host_keys_handler: (optional) + A `callable` or coroutine handler function which if set will be + called when a global request from the server is received which + provides an updated list of server host keys. The handler takes + four arguments (added, removed, retained, and revoked), each of + which is a list of SSHKey public keys, reflecting differences + between what the server reported and what is currently matching + in known_hosts. + + .. note:: This handler will only be called when known + host checking is enabled and the check succeeded. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 server certificate authentication. If no trusted certificates are @@ -6424,9 +7596,10 @@ class SSHClientConnectionOptions(SSHConnectionOptions): the currently logged in user on the local machine will be used. :param password: (optional) The password to use for client password authentication or - keyboard-interactive authentication which prompts for a password. - If this is not specified, client password authentication will - not be performed. + keyboard-interactive authentication which prompts for a password, + or a `callable` or coroutine which returns the password to use. + If this is not specified or set to `None`, client password + authentication will not be performed. :param client_host_keysign: (optional) Whether or not to use `ssh-keysign` to sign host-based authentication requests. If set to `True`, an attempt will be @@ -6474,15 +7647,25 @@ class SSHClientConnectionOptions(SSHConnectionOptions): A list of optional certificates which can be paired with the provided client keys. :param passphrase: (optional) - The passphrase to use to decrypt client keys when loading them, - if they are encrypted. If this is not specified, only unencrypted - client keys can be loaded. If the keys passed into client_keys - are already loaded, this argument is ignored. + The passphrase to use to decrypt client keys if they are + encrypted, or a `callable` or coroutine which takes a filename + as a parameter and returns the passphrase to use to decrypt + that file. If not specified, only unencrypted client keys can + be loaded. If the keys passed into client_keys are already + loaded, this argument is ignored. + + .. note:: A callable or coroutine passed in as a passphrase + will be called on all filenames configured as + client keys or client host keys each time an + SSHClientConnectionOptions object is instantiated, + even if the keys aren't encrypted or aren't ever + used for authentication. + :param ignore_encrypted: (optional) Whether or not to ignore encrypted keys when no passphrase is - provided. This is intended to allow encrypted keys specified via - the IdentityFile config option to be ignored if a passphrase - is not specified, loading only unencrypted local keys. Note + specified. This defaults to `True` when keys are specified via + the IdentityFile config option, causing encrypted keys in the + config to be ignored when no passphrase is specified. Note that encrypted keys loaded into an SSH agent can still be used when this option is set. :param host_based_auth: (optional) @@ -6501,12 +7684,14 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :param password_auth: (optional) Whether or not to allow password authentication. By default, password authentication is enabled if a password is specified - or if callbacks to provide a password are made availble. + or if callbacks to provide a password are made available. :param gss_host: (optional) The principal name to use for the host in GSS key exchange and authentication. If not specified, this value will be the same as the `host` argument. If this argument is explicitly set to `None`, GSS key exchange and authentication will not be performed. + :param gss_store: (optional) + The GSS credential store from which to acquire credentials. :param gss_kex: (optional) Whether or not to allow GSS key exchange. By default, GSS key exchange is enabled. @@ -6547,8 +7732,11 @@ class SSHClientConnectionOptions(SSHConnectionOptions): made available for use. This is the default. :param agent_forwarding: (optional) Whether or not to allow forwarding of ssh-agent requests from - processes running on the server. By default, ssh-agent forwarding - requests from the server are not allowed. + processes running on the server. This argument can also be set + to the path of a UNIX domain socket in cases where forwarded + agent requests should be sent to a different path than client + agent requests. By default, forwarding ssh-agent requests from + the server is not allowed. :param pkcs11_provider: (optional) The path of a shared library which should be used as a PKCS#11 provider for accessing keys on PIV security tokens. By default, @@ -6586,16 +7774,17 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or - `None` to disable compression. + `None` to disable compression. The client prefers to disable + compression, but will enable it if the server requires it. :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH handshake, taken from :ref:`signature algorithms `. :param rekey_bytes: (optional) The number of bytes which can be sent before the SSH session - key is renegotiated. This defaults to 1 GB. + key is renegotiated, defaulting to 1 GB. :param rekey_seconds: (optional) The maximum time in seconds before the SSH session key is - renegotiated. This defaults to 1 hour. + renegotiated, defaulting to 1 hour. :param connect_timeout: (optional) The maximum time in seconds allowed to complete an outbound SSH connection. This includes the time to establish the TCP @@ -6621,6 +7810,40 @@ class SSHClientConnectionOptions(SSHConnectionOptions): without getting a response before disconnecting from the server. This defaults to 3, but only applies when keepalive_interval is non-zero. + :param tcp_keepalive: (optional) + Whether or not to enable keepalive probes at the TCP level to + detect broken connections, defaulting to `True`. + :param canonicalize_hostname: (optional) + Whether or not to enable hostname canonicalization, defaulting + to `False`, in which case hostnames are passed as-is to the + system resolver. If set to `True`, requests that don't involve + a proxy tunnel or command will attempt to canonicalize the hostname + using canonical_domains and rules in canonicalize_permitted_cnames. + If set to `'always'`, hostname canonicalization is also applied + to proxied requests. + :param canonical_domains: (optional) + When canonicalize_hostname is set, this specifies list of domain + suffixes in which to search for the hostname. + :param canonicalize_fallback_local: (optional) + Whether or not to fall back to looking up the hostname against + the system resolver's search domains when no matches are found + in canonical_domains, defaulting to `True`. + :param canonicalize_max_dots: (optional) + Tha maximum number of dots which can appear in a hostname + before hostname canonicalization is disabled, defaulting + to 1. Hostnames with more than this number of dots are + treated as already being fully qualified and passed as-is + to the system resolver. + :param canonicalize_permitted_cnames: (optional) + Patterns to match against to decide whether hostname + canonicalization should return a CNAME. This argument + contains a list of pairs of wildcard pattern lists. The + first pattern is matched against the hostname found after + adding one of the search domains from canonical_domains and + the second pattern is matched against the associated CNAME. + If a match can be found in the list for both patterns, the + CNAME is returned as the canonical hostname. The default + is an empty list, preventing CNAMEs from being returned. :param command: (optional) The default remote command to execute on client sessions. An interactive shell is started if no command or subsystem is @@ -6691,7 +7914,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. .. note:: Specifying configuration files when creating an @@ -6710,11 +7933,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions): build up a configuration. When an option is not explicitly specified, its value will be pulled from this options object (if present) before falling back to the default value. - :type client_factory: `callable` returning :class:`SSHClientConnection` + :type client_factory: `callable` returning :class:`SSHClient` :type proxy_command: `str` or `list` of `str` :type known_hosts: *see* :ref:`SpecifyingKnownHosts` :type host_key_alias: `str` :type server_host_key_algs: `str` or `list` of `str` + :type server_host_keys_handler: `callable` or coroutine :type x509_trusted_certs: *see* :ref:`SpecifyingCertificates` :type x509_trusted_cert_paths: `list` of `str` :type x509_purposes: *see* :ref:`SpecifyingX509Purposes` @@ -6735,6 +7959,8 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :type kbdint_auth: `bool` :type password_auth: `bool` :type gss_host: `str` + :type gss_store: + `str`, `bytes`, or a `dict` with `str` or `bytes` keys and values :type gss_kex: `bool` :type gss_auth: `bool` :type gss_delegate_creds: `bool` @@ -6743,8 +7969,8 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :type agent_path: `str` :type agent_identities: *see* :ref:`SpecifyingPublicKeys` and :ref:`SpecifyingCertificates` - :type agent_forwarding: `bool` - :type pkcs11_provider: `str` + :type agent_forwarding: `bool` or `str` + :type pkcs11_provider: `str` or `None` :type pkcs11_pin: `str` :type client_version: `str` :type kex_algs: `str` or `list` of `str` @@ -6758,6 +7984,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions): :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` + :type tcp_keepalive: `bool` + :type canonicalize_hostname: `bool` or `'always'` + :type canonical_domains: `list` of `str` + :type canonicalize_fallback_local: `bool` + :type canonicalize_max_dots: `int` + :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type command: `str` :type subsystem: `str` :type env: `dict` with `str` keys and values @@ -6785,6 +8017,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions): known_hosts: KnownHostsArg host_key_alias: Optional[str] server_host_key_algs: Union[str, Sequence[str]] + server_host_keys_handler: _ServerHostKeysHandler username: str password: Optional[str] client_host_keysign: Optional[str] @@ -6796,6 +8029,7 @@ class SSHClientConnectionOptions(SSHConnectionOptions): client_certs: Sequence[FilePath] ignore_encrypted: bool gss_host: DefTuple[Optional[str]] + gss_store: Optional[Dict[BytesOrStr, BytesOrStr]] gss_kex: bool gss_auth: bool gss_delegate_creds: bool @@ -6804,12 +8038,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions): agent_path: Optional[str] agent_identities: Optional[Sequence[bytes]] agent_forward_path: Optional[str] - pkcs11_provider: Optional[FilePath] + pkcs11_provider: Optional[str] pkcs11_pin: Optional[str] command: Optional[str] subsystem: Optional[str] - env: _Env - send_env: _SendEnv + env: Optional[Env] + send_env: Optional[EnvSeq] request_pty: _RequestPTY term_type: Optional[str] term_size: TermSizeArg @@ -6824,8 +8058,11 @@ class SSHClientConnectionOptions(SSHConnectionOptions): max_pktsize: int # pylint: disable=arguments-differ - def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore + def prepare(self, # type: ignore + loop: Optional[asyncio.AbstractEventLoop] = None, + last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, + canonical: bool = False, final: bool = False, client_factory: Optional[_ClientFactory] = None, client_version: _VersionArg = (), host: str = '', port: DefTuple[int] = (), tunnel: object = (), @@ -6833,6 +8070,11 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), + canonicalize_hostname: DefTuple[Union[bool, str]] = (), + canonical_domains: DefTuple[Sequence[str]] = (), + canonicalize_fallback_local: DefTuple[bool] = (), + canonicalize_max_dots: DefTuple[int] = (), + canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), @@ -6850,29 +8092,32 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore known_hosts: KnownHostsArg = (), host_key_alias: DefTuple[Optional[str]] = (), server_host_key_algs: _AlgsArg = (), + server_host_keys_handler: _ServerHostKeysHandler = None, username: DefTuple[str] = (), password: Optional[str] = None, client_host_keysign: DefTuple[KeySignPath] = (), - client_host_keys: _ClientKeysArg = None, + client_host_keys: Optional[_ClientKeysArg] = None, client_host_certs: Sequence[FilePath] = (), client_host: Optional[str] = None, client_username: DefTuple[str] = (), client_keys: _ClientKeysArg = (), client_certs: Sequence[FilePath] = (), passphrase: Optional[BytesOrStr] = None, - ignore_encrypted: bool = False, + ignore_encrypted: DefTuple[bool] = (), gss_host: DefTuple[Optional[str]] = (), + gss_store: Optional[Union[BytesOrStr, BytesOrStrDict]] = None, gss_kex: DefTuple[bool] = (), gss_auth: DefTuple[bool] = (), gss_delegate_creds: DefTuple[bool] = (), preferred_auth: DefTuple[Union[str, Sequence[str]]] = (), disable_trivial_auth: bool = False, agent_path: DefTuple[Optional[str]] = (), agent_identities: DefTuple[Optional[IdentityListArg]] = (), - agent_forwarding: DefTuple[bool] = (), - pkcs11_provider: DefTuple[Optional[FilePath]] = (), + agent_forwarding: DefTuple[Union[bool, str]] = (), + pkcs11_provider: DefTuple[Optional[str]] = (), pkcs11_pin: Optional[str] = None, command: DefTuple[Optional[str]] = (), - subsystem: Optional[str] = None, env: DefTuple[_Env] = (), - send_env: DefTuple[_SendEnv] = (), + subsystem: Optional[str] = None, + env: DefTuple[Optional[Env]] = (), + send_env: DefTuple[Optional[EnvSeq]] = (), request_pty: DefTuple[_RequestPTY] = (), term_type: Optional[str] = None, term_size: TermSizeArg = None, @@ -6898,8 +8143,9 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore config = [default_config] if os.access(default_config, os.R_OK) else [] - config = SSHClientConfig.load(last_config, config, reload, - local_username, username, host, port) + config = SSHClientConfig.load(last_config, config, reload, canonical, + final, local_username, username, host, + port) if x509_trusted_certs == (): default_x509_certs = Path('~', '.ssh', 'ca-bundle.crt').expanduser() @@ -6935,17 +8181,27 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore super().prepare(config, client_factory or SSHClient, client_version, host, port, tunnel, passphrase, proxy_command, family, - local_addr, tcp_keepalive, kex_algs, encryption_algs, - mac_algs, compression_algs, signature_algs, - host_based_auth, public_key_auth, kbdint_auth, - password_auth, x509_trusted_certs, + local_addr, tcp_keepalive, canonicalize_hostname, + canonical_domains, canonicalize_fallback_local, + canonicalize_max_dots, canonicalize_permitted_cnames, + kex_algs, encryption_algs, mac_algs, compression_algs, + signature_algs, host_based_auth, public_key_auth, + kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) - self.known_hosts = known_hosts if known_hosts != () else \ - (cast(List[str], config.get('UserKnownHostsFile', [])) + - cast(List[str], config.get('GlobalKnownHostsFile', []))) or () + if known_hosts != (): + self.known_hosts = known_hosts + else: + user_known_hosts = \ + cast(List[str], config.get('UserKnownHostsFile', ())) + + if user_known_hosts == []: + self.known_hosts = None + else: + self.known_hosts = list(user_known_hosts) + \ + cast(List[str], config.get('GlobalKnownHostsFile', [])) self.host_key_alias = \ cast(Optional[str], host_key_alias if host_key_alias != () else @@ -6958,6 +8214,8 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore _select_host_key_algs(server_host_key_algs, cast(DefTuple[str], config.get('HostKeyAlgorithms', ())), []) + self.server_host_keys_handler = server_host_keys_handler + self.username = saslprep(cast(str, username if username != () else config.get('User', local_username))) @@ -6980,7 +8238,7 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore self.client_host_keypairs = \ load_keypairs(cast(KeyPairListArg, client_host_keys), - passphrase, client_host_certs) + passphrase, client_host_certs, loop=loop) self.client_host_keysign = client_host_keysign self.client_host = client_host @@ -6990,6 +8248,11 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore self.gss_host = gss_host + if isinstance(gss_store, (bytes, str)): + self.gss_store = {'ccache': gss_store} + else: + self.gss_store = gss_store + self.gss_kex = cast(bool, gss_kex if gss_kex != () else config.get('GSSAPIKeyExchange', True)) @@ -7017,16 +8280,20 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore agent_path = cast(DefTuple[str], config.get('IdentityAgent', ())) if agent_path == (): - agent_path = \ - cast(DefTuple[str], os.environ.get('SSH_AUTH_SOCK', None)) + agent_path = os.environ.get('SSH_AUTH_SOCK', '') - agent_path = str(Path(agent_path).expanduser()) if agent_path else None + agent_path = str(Path(agent_path).expanduser()) if agent_path else '' if pkcs11_provider == (): pkcs11_provider = \ - cast(Optional[FilePath], config.get('PKCS11Provider')) + cast(Optional[str], config.get('PKCS11Provider')) + + pkcs11_provider: Optional[str] - pkcs11_provider: Optional[FilePath] + if ignore_encrypted == (): + ignore_encrypted = client_keys == () + + ignore_encrypted: bool if client_keys == (): client_keys = cast(_ClientKeysArg, config.get('IdentityFile', ())) @@ -7054,7 +8321,8 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore if client_keys: self.client_keys = \ load_keypairs(cast(KeyPairListArg, client_keys), passphrase, - client_certs, identities_only, ignore_encrypted) + client_certs, identities_only, ignore_encrypted, + loop=loop) elif client_keys is not None: self.client_keys = load_default_keypairs(passphrase, client_certs) else: @@ -7070,18 +8338,27 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore self.pkcs11_pin = None if agent_forwarding == (): - agent_forwarding = cast(bool, config.get('ForwardAgent', False)) + agent_forwarding = cast(Union[bool, str], + config.get('ForwardAgent', False)) + + agent_forwarding: Union[bool, str] - self.agent_forward_path = agent_path if agent_forwarding else None + if not agent_forwarding: + self.agent_forward_path = None + elif agent_forwarding is True: + self.agent_forward_path = agent_path + else: + self.agent_forward_path = agent_forwarding self.command = cast(Optional[str], command if command != () else config.get('RemoteCommand')) self.subsystem = subsystem - self.env = cast(_Env, env if env != () else config.get('SetEnv')) + self.env = cast(Optional[Env], env if env != () else + config.get('SetEnv')) - self.send_env = cast(_SendEnv, send_env if send_env != () else + self.send_env = cast(Optional[EnvSeq], send_env if send_env != () else config.get('SendEnv')) self.request_pty = cast(_RequestPTY, request_pty if request_pty != () @@ -7128,12 +8405,35 @@ class SSHServerConnectionOptions(SSHConnectionOptions): :param server_host_certs: (optional) A list of optional certificates which can be paired with the provided server host keys. + :param send_server_host_keys: (optional) + Whether or not to send a list of the allowed server host keys + for clients to use to update their known hosts like for the + server. + + .. note:: Enabling this option will allow multiple server + host keys of the same type to be configured. Only + the first key of each type will be actively used + during key exchange, but the others will be + reported as reserved keys that clients should + begin to trust, to allow for future key rotation. + If this option is disabled, specifying multiple + server host keys of the same type is treated as + a configuration error. :param passphrase: (optional) - The passphrase to use to decrypt server host keys when loading - them, if they are encrypted. If this is not specified, only - unencrypted server host keys can be loaded. If the keys passed - into server_host_keys are already loaded, this argument is - ignored. + The passphrase to use to decrypt server host keys if they are + encrypted, or a `callable` or coroutine which takes a filename + as a parameter and returns the passphrase to use to decrypt + that file. If not specified, only unencrypted server host keys + can be loaded. If the keys passed into server_host_keys are + already loaded, this argument is ignored. + + .. note:: A callable or coroutine passed in as a passphrase + will be called on all filenames configured as + server host keys each time an + SSHServerConnectionOptions object is instantiated, + even if the keys aren't encrypted or aren't ever + used for server validation. + :param known_client_hosts: (optional) A list of client hosts which should be trusted to perform host-based client authentication. If this is not specified, @@ -7146,7 +8446,7 @@ class SSHServerConnectionOptions(SSHConnectionOptions): client connected from. :param authorized_client_keys: (optional) A list of authorized user and CA public keys which should be - trusted for certifcate-based client public key authentication. + trusted for certificate-based client public key authentication. :param x509_trusted_certs: (optional) A list of certificates which should be trusted for X.509 client certificate authentication. If this argument is explicitly set @@ -7198,6 +8498,8 @@ class SSHServerConnectionOptions(SSHConnectionOptions): name. Otherwise, the value used by :func:`socket.getfqdn` will be used. If this argument is explicitly set to `None`, GSS key exchange and authentication will not be performed. + :param gss_store: (optional) + The GSS credential store from which to acquire credentials. :param gss_kex: (optional) Whether or not to allow GSS key exchange. By default, GSS key exchange is enabled. @@ -7262,11 +8564,11 @@ class SSHServerConnectionOptions(SSHConnectionOptions): errors of data exchanged on sessions on this server, defaulting to 'strict'. :param sftp_factory: (optional) - A `callable` which returns an :class:`SFTPServer` object that - will be created each time an SFTP session is requested by the - client, or `True` to use the base :class:`SFTPServer` class - to handle SFTP requests. If not specified, SFTP sessions are - rejected by default. + A `callable` or coroutine which returns an :class:`SFTPServer` + object that will be created each time an SFTP session is + requested by the client, or `True` to use the base + :class:`SFTPServer` class to handle SFTP requests. If not + specified, SFTP sessions are rejected by default. :param sftp_version: (optional) The maximum version of the SFTP protocol to support, currently either 3 or 4, defaulting to 3. @@ -7285,26 +8587,28 @@ class SSHServerConnectionOptions(SSHConnectionOptions): this server, defaulting to `'AsyncSSH'` and its version number. :param kex_algs: (optional) A list of allowed key exchange algorithms in the SSH handshake, - taken from :ref:`key exchange algorithms ` + taken from :ref:`key exchange algorithms `, :param encryption_algs: (optional) A list of encryption algorithms to use during the SSH handshake, - taken from :ref:`encryption algorithms ` + taken from :ref:`encryption algorithms `. :param mac_algs: (optional) A list of MAC algorithms to use during the SSH handshake, taken - from :ref:`MAC algorithms ` + from :ref:`MAC algorithms `. :param compression_algs: (optional) A list of compression algorithms to use during the SSH handshake, taken from :ref:`compression algorithms `, or - `None` to disable compression + `None` to disable compression. The server defaults to allowing + either no compression or compression after auth, depending on + what the client requests. :param signature_algs: (optional) A list of public key signature algorithms to use during the SSH - handshake, taken from :ref:`signature algorithms ` + handshake, taken from :ref:`signature algorithms `. :param rekey_bytes: (optional) The number of bytes which can be sent before the SSH session - key is renegotiated, defaulting to 1 GB + key is renegotiated, defaulting to 1 GB. :param rekey_seconds: (optional) The maximum time in seconds before the SSH session key is - renegotiated, defaulting to 1 hour + renegotiated, defaulting to 1 hour. :param connect_timeout: (optional) The maximum time in seconds allowed to complete an outbound SSH connection. This includes the time to establish the TCP @@ -7314,8 +8618,8 @@ class SSHServerConnectionOptions(SSHConnectionOptions): and AsyncSSH's login timeout. :param login_timeout: (optional) The maximum time in seconds allowed for authentication to - complete, defaulting to 2 minutes. Setting this to 0 - will disable the login timeout. + complete, defaulting to 2 minutes. Setting this to 0 will + disable the login timeout. .. note:: This timeout only applies after the SSH TCP connection is established. To set a timeout @@ -7332,11 +8636,32 @@ class SSHServerConnectionOptions(SSHConnectionOptions): non-zero. :param tcp_keepalive: (optional) Whether or not to enable keepalive probes at the TCP level to - detect broken connections, defaulting to `True` + detect broken connections, defaulting to `True`. + :param canonicalize_hostname: (optional) + Whether or not to enable hostname canonicalization, defaulting + to `False`, in which case hostnames are passed as-is to the + system resolver. If set to `True`, requests that don't involve + a proxy tunnel or command will attempt to canonicalize the hostname + using canonical_domains and rules in canonicalize_permitted_cnames. + If set to `'always'`, hostname canonicalization is also applied + to proxied requests. + :param canonical_domains: (optional) + When canonicalize_hostname is set, this specifies list of domain + suffixes in which to search for the hostname. + :param canonicalize_fallback_local: (optional) + Whether or not to fall back to looking up the hostname against + the system resolver's search domains when no matches are found + in canonical_domains, defaulting to `True`. + :param canonicalize_max_dots: (optional) + Tha maximum number of dots which can appear in a hostname + before hostname canonicalization is disabled, defaulting + to 1. Hostnames with more than this number of dots are + treated as already being fully qualified and passed as-is + to the system resolver. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. .. note:: Specifying configuration files when creating an @@ -7355,11 +8680,12 @@ class SSHServerConnectionOptions(SSHConnectionOptions): build up a configuration. When an option is not explicitly specified, its value will be pulled from this options object (if present) before falling back to the default value. - :type server_factory: `callable` returning :class:`SSHServerConnection` + :type server_factory: `callable` returning :class:`SSHServer` :type proxy_command: `str` or `list` of `str` :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type server_host_keys: *see* :ref:`SpecifyingPrivateKeys` :type server_host_certs: *see* :ref:`SpecifyingCertificates` + :type send_server_host_keys: `bool` :type passphrase: `str` or `bytes` :type known_client_hosts: *see* :ref:`SpecifyingKnownHosts` :type trust_client_host: `bool` @@ -7372,6 +8698,8 @@ class SSHServerConnectionOptions(SSHConnectionOptions): :type kbdint_auth: `bool` :type password_auth: `bool` :type gss_host: `str` + :type gss_store: + `str`, `bytes`, or a `dict` with `str` or `bytes` keys and values :type gss_kex: `bool` :type gss_auth: `bool` :type allow_pty: `bool` @@ -7383,11 +8711,11 @@ class SSHServerConnectionOptions(SSHConnectionOptions): :type x11_forwarding: `bool` :type x11_auth_path: `str` :type agent_forwarding: `bool` - :type process_factory: `callable` - :type session_factory: `callable` + :type process_factory: `callable` or coroutine + :type session_factory: `callable` or coroutine :type encoding: `str` or `None` :type errors: `str` - :type sftp_factory: `callable` + :type sftp_factory: `callable` or coroutine :type sftp_version: `int` :type allow_scp: `bool` :type window: `int` @@ -7404,6 +8732,12 @@ class SSHServerConnectionOptions(SSHConnectionOptions): :type login_timeout: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_interval: *see* :ref:`SpecifyingTimeIntervals` :type keepalive_count_max: `int` + :type tcp_keepalive: `bool` + :type canonicalize_hostname: `bool` or `'always'` + :type canonical_domains: `list` of `str` + :type canonicalize_fallback_local: `bool` + :type canonicalize_max_dots: `int` + :type canonicalize_permitted_cnames: `list` of `tuple` of 2 `str` values :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` @@ -7413,10 +8747,13 @@ class SSHServerConnectionOptions(SSHConnectionOptions): server_factory: _ServerFactory server_version: bytes server_host_keys: 'OrderedDict[bytes, SSHKeyPair]' + all_server_host_keys: 'OrderedDict[bytes, SSHKeyPair]' + send_server_host_keys: bool known_client_hosts: KnownHostsArg trust_client_host: bool authorized_client_keys: DefTuple[Optional[SSHAuthorizedKeys]] gss_host: Optional[str] + gss_store: Optional[Dict[BytesOrStr, BytesOrStr]] gss_kex: bool gss_auth: bool allow_pty: bool @@ -7439,8 +8776,11 @@ class SSHServerConnectionOptions(SSHConnectionOptions): max_pktsize: int # pylint: disable=arguments-differ - def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore + def prepare(self, # type: ignore + loop: Optional[asyncio.AbstractEventLoop] = None, + last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, + canonical: bool = False, final: bool = False, accept_addr: str = '', accept_port: int = 0, username: str = '', client_host: str = '', client_addr: str = '', @@ -7451,6 +8791,11 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore family: DefTuple[int] = (), local_addr: DefTuple[HostPort] = (), tcp_keepalive: DefTuple[bool] = (), + canonicalize_hostname: DefTuple[Union[bool, str]] = (), + canonical_domains: DefTuple[Sequence[str]] = (), + canonicalize_fallback_local: DefTuple[bool] = (), + canonicalize_max_dots: DefTuple[int] = (), + canonicalize_permitted_cnames: DefTuple[Sequence[str]] = (), kex_algs: _AlgsArg = (), encryption_algs: _AlgsArg = (), mac_algs: _AlgsArg = (), compression_algs: _AlgsArg = (), signature_algs: _AlgsArg = (), host_based_auth: _AuthArg = (), @@ -7467,11 +8812,13 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore keepalive_count_max: DefTuple[int] = (), server_host_keys: KeyPairListArg = (), server_host_certs: CertListArg = (), + send_server_host_keys: bool = False, passphrase: Optional[BytesOrStr] = None, known_client_hosts: KnownHostsArg = None, trust_client_host: bool = False, authorized_client_keys: _AuthKeysArg = (), gss_host: DefTuple[Optional[str]] = (), + gss_store: Optional[Union[BytesOrStr, BytesOrStrDict]] = None, gss_kex: DefTuple[bool] = (), gss_auth: DefTuple[bool] = (), allow_pty: DefTuple[bool] = (), @@ -7492,8 +8839,8 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore max_pktsize: int = _DEFAULT_MAX_PKTSIZE) -> None: """Prepare server connection configuration options""" - config = SSHServerConfig.load(last_config, config, reload, - accept_addr, accept_port, username, + config = SSHServerConfig.load(last_config, config, reload, canonical, + final, accept_addr, accept_port, username, client_host, client_addr) if login_timeout == (): @@ -7519,10 +8866,12 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore super().prepare(config, server_factory or SSHServer, server_version, host, port, tunnel, passphrase, proxy_command, family, - local_addr, tcp_keepalive, kex_algs, encryption_algs, - mac_algs, compression_algs, signature_algs, - host_based_auth, public_key_auth, kbdint_auth, - password_auth, x509_trusted_certs, + local_addr, tcp_keepalive, canonicalize_hostname, + canonical_domains, canonicalize_fallback_local, + canonicalize_max_dots, canonicalize_permitted_cnames, + kex_algs, encryption_algs, mac_algs, compression_algs, + signature_algs, host_based_auth, public_key_auth, + kbdint_auth, password_auth, x509_trusted_certs, x509_trusted_cert_paths, x509_purposes, rekey_bytes, rekey_seconds, connect_timeout, login_timeout, keepalive_interval, keepalive_count_max) @@ -7535,17 +8884,24 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore config.get('HostCertificate', ())) server_keys = load_keypairs(server_host_keys, passphrase, - server_host_certs) + server_host_certs, loop=loop) self.server_host_keys = OrderedDict() + self.all_server_host_keys = OrderedDict() for keypair in server_keys: for alg in keypair.host_key_algorithms: - if alg in self.server_host_keys: - raise ValueError('Multiple keys of type %s found' % - alg.decode('ascii')) + if alg in self.server_host_keys and not send_server_host_keys: + raise ValueError('Multiple keys of type ' + f'{alg.decode("ascii")} found: ' + 'Enable send_server_host_keys to ' + 'allow reserved keys to be configured') - self.server_host_keys[alg] = keypair + if alg not in self.server_host_keys: + self.server_host_keys[alg] = keypair + + if send_server_host_keys: + self.all_server_host_keys[keypair.public_data] = keypair self.known_client_hosts = known_client_hosts self.trust_client_host = trust_client_host @@ -7570,6 +8926,11 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore self.gss_host = gss_host + if isinstance(gss_store, (bytes, str)): + self.gss_store = {'ccache': gss_store} + else: + self.gss_store = gss_store + self.gss_kex = cast(bool, gss_kex if gss_kex != () else config.get('GSSAPIKeyExchange', True)) @@ -7609,10 +8970,114 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore @async_context_manager -async def connect(host: str, port: DefTuple[int] = (), *, +async def run_client(sock: socket.socket, config: DefTuple[ConfigPaths] = (), + options: Optional[SSHClientConnectionOptions] = None, + **kwargs: object) -> SSHClientConnection: + """Start an SSH client connection on an already-connected socket + + This function is a coroutine which starts an SSH client on an + existing already-connected socket. It can be used instead of + :func:`connect` when a socket is connected outside of asyncio. + + :param sock: + An existing already-connected socket to run an SSH client on, + instead of opening up a new connection. + :param config: (optional) + Paths to OpenSSH client configuration files to load. This + configuration will be used as a fallback to override the + defaults for settings which are not explicitly specified using + AsyncSSH's configuration options. If no paths are specified and + no config paths were set when constructing the `options` + argument (if any), an attempt will be made to load the + configuration from the file :file:`.ssh/config`. If this + argument is explicitly set to `None`, no new configuration + files will be loaded, but any configuration loaded when + constructing the `options` argument will still apply. See + :ref:`SupportedClientConfigOptions` for details on what + configuration options are currently supported. + :param options: (optional) + Options to use when establishing the SSH client connection. These + options can be specified either through this parameter or as direct + keyword arguments to this function. + :type sock: :class:`socket.socket` + :type config: `list` of `str` + :type options: :class:`SSHClientConnectionOptions` + + :returns: :class:`SSHClientConnection` + + """ + + def conn_factory() -> SSHClientConnection: + """Return an SSH client connection factory""" + + return SSHClientConnection(loop, new_options, wait='auth') + + loop = asyncio.get_event_loop() + + new_options = await SSHClientConnectionOptions.construct( + options, config=config, **kwargs) + + return await asyncio.wait_for( + _connect(new_options, config, loop, 0, sock, conn_factory, + 'Starting SSH client on'), + timeout=new_options.connect_timeout) + + +@async_context_manager +async def run_server(sock: socket.socket, config: DefTuple[ConfigPaths] = (), + options: Optional[SSHServerConnectionOptions] = None, + **kwargs: object) -> SSHServerConnection: + """Start an SSH server connection on an already-connected socket + + This function is a coroutine which starts an SSH server on an + existing already-connected TCP socket. It can be used instead of + :func:`listen` when connections are accepted outside of asyncio. + + :param sock: + An existing already-connected socket to run SSH over, instead of + opening up a new connection. + :param config: (optional) + Paths to OpenSSH server configuration files to load. This + configuration will be used as a fallback to override the + defaults for settings which are not explicitly specified using + AsyncSSH's configuration options. By default, no OpenSSH + configuration files will be loaded. See + :ref:`SupportedServerConfigOptions` for details on what + configuration options are currently supported. + :param options: (optional) + Options to use when starting the reverse-direction SSH server. + These options can be specified either through this parameter + or as direct keyword arguments to this function. + :type sock: :class:`socket.socket` + :type config: `list` of `str` + :type options: :class:`SSHServerConnectionOptions` + + :returns: :class:`SSHServerConnection` + + """ + + def conn_factory() -> SSHServerConnection: + """Return an SSH server connection factory""" + + return SSHServerConnection(loop, new_options, wait='auth') + + loop = asyncio.get_event_loop() + + new_options = await SSHServerConnectionOptions.construct( + options, config=config, **kwargs) + + return await asyncio.wait_for( + _connect(new_options, config, loop, 0, sock, conn_factory, + 'Starting SSH server on'), + timeout=new_options.connect_timeout) + + +@async_context_manager +async def connect(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), + sock: Optional[socket.socket] = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHClientConnectionOptions] = None, **kwargs: object) -> SSHClientConnection: @@ -7642,7 +9107,7 @@ async def connect(host: str, port: DefTuple[int] = (), *, If an error occurs, it will be raised as an exception and the partially open connection and client objects will be cleaned up. - :param host: + :param host: (optional) The hostname or address to connect to. :param port: (optional) The port number to connect to. If not specified, the default @@ -7653,8 +9118,20 @@ async def connect(host: str, port: DefTuple[int] = (), *, over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. @@ -7662,10 +9139,14 @@ async def connect(host: str, port: DefTuple[int] = (), *, The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting + :param sock: (optional) + An existing already-connected socket to run SSH over, instead of + opening up a new connection. When this is specified, none of + host, port family, flags, or local_addr should be specified. :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the @@ -7685,6 +9166,7 @@ async def connect(host: str, port: DefTuple[int] = (), *, :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` + :type sock: :class:`socket.socket` or `None` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` @@ -7699,23 +9181,23 @@ def conn_factory() -> SSHClientConnection: loop = asyncio.get_event_loop() - new_options = cast(SSHClientConnectionOptions, await _run_in_executor( - loop, SSHClientConnectionOptions, options, config=config, - host=host, port=port, tunnel=tunnel, family=family, - local_addr=local_addr, **kwargs)) + new_options = await SSHClientConnectionOptions.construct( + options, config=config, host=host, port=port, tunnel=tunnel, + family=family, local_addr=local_addr, **kwargs) return await asyncio.wait_for( - _connect(new_options, loop, flags, conn_factory, + _connect(new_options, config, loop, flags, sock, conn_factory, 'Opening SSH connection to'), timeout=new_options.connect_timeout) @async_context_manager async def connect_reverse( - host: str, port: DefTuple[int] = (), *, + host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), + sock: Optional[socket.socket] = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHServerConnectionOptions] = None, **kwargs: object) -> SSHServerConnection: @@ -7731,7 +9213,7 @@ async def connect_reverse( that the `options` are of type :class:`SSHServerConnectionOptions` instead of :class:`SSHClientConnectionOptions`. - :param host: + :param host: (optional) The hostname or address to connect to. :param port: (optional) The port number to connect to. If not specified, the default @@ -7742,8 +9224,20 @@ async def connect_reverse( over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param family: (optional) The address family to use when creating the socket. By default, the address family is automatically selected based on the host. @@ -7751,10 +9245,14 @@ async def connect_reverse( The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting + :param sock: (optional) + An existing already-connected socket to run SSH over, instead of + opening up a new connection. When this is specified, none of + host, port family, flags, or local_addr should be specified. :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. By default, no OpenSSH configuration files will be loaded. See :ref:`SupportedServerConfigOptions` for details on what @@ -7769,6 +9267,7 @@ async def connect_reverse( :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` + :type sock: :class:`socket.socket` or `None` :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` @@ -7777,29 +9276,29 @@ async def connect_reverse( """ def conn_factory() -> SSHServerConnection: - """Return an SSH client connection factory""" + """Return an SSH server connection factory""" return SSHServerConnection(loop, new_options, wait='auth') loop = asyncio.get_event_loop() - new_options = cast(SSHServerConnectionOptions, await _run_in_executor( - loop, SSHServerConnectionOptions, options, config=config, - host=host, port=port, tunnel=tunnel, family=family, - local_addr=local_addr, **kwargs)) + new_options = await SSHServerConnectionOptions.construct( + options, config=config, host=host, port=port, tunnel=tunnel, + family=family, local_addr=local_addr, **kwargs) return await asyncio.wait_for( - _connect(new_options, loop, flags, conn_factory, + _connect(new_options, config, loop, flags, sock, conn_factory, 'Opening reverse SSH connection to'), timeout=new_options.connect_timeout) @async_context_manager -async def listen(host: str = '', port: DefTuple[int] = (), *, +async def listen(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelListener] = (), family: DefTuple[int] = (), flags:int = socket.AI_PASSIVE, - backlog: int = 100, reuse_address: bool = False, - reuse_port: bool = False, acceptor: _AcceptHandler = None, + backlog: int = 100, sock: Optional[socket.socket] = None, + reuse_address: bool = False, reuse_port: bool = False, + acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, config: DefTuple[ConfigPaths] = (), options: Optional[SSHServerConnectionOptions] = None, @@ -7817,13 +9316,25 @@ async def listen(host: str = '', port: DefTuple[int] = (), *, The port number to listen on. If not specified, the default SSH port is used. :param tunnel: (optional) - An existing SSH client connection that this new listener should - be forwarded over. If set, a remote TCP/IP listener will be - opened on this connection on the requested host and port rather - than listening directly via TCP. A string of the form + An existing SSH client connection that this new connection should + be tunneled over. If set, a direct TCP/IP tunnel will be opened + over this connection to the requested host and port rather than + connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param family: (optional) The address family to use when creating the server. By default, the address families are automatically selected based on the host. @@ -7831,6 +9342,10 @@ async def listen(host: str = '', port: DefTuple[int] = (), *, The flags to pass to getaddrinfo() when looking up the host :param backlog: (optional) The maximum number of queued connections allowed on listeners + :param sock: (optional) + A pre-existing socket to use instead of creating and binding + a new socket. When this is specified, host and port should not + be specified. :param reuse_address: (optional) Whether or not to reuse a local socket in the TIME_WAIT state without waiting for its natural timeout to expire. If not @@ -7854,7 +9369,7 @@ async def listen(host: str = '', port: DefTuple[int] = (), *, :param config: (optional) Paths to OpenSSH server configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. By default, no OpenSSH configuration files will be loaded. See :ref:`SupportedServerConfigOptions` for details on what @@ -7869,8 +9384,11 @@ async def listen(host: str = '', port: DefTuple[int] = (), *, :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type backlog: `int` + :type sock: :class:`socket.socket` or `None` :type reuse_address: `bool` :type reuse_port: `bool` + :type acceptor: `callable` or coroutine + :type error_handler: `callable` :type config: `list` of `str` :type options: :class:`SSHServerConnectionOptions` @@ -7879,30 +9397,31 @@ async def listen(host: str = '', port: DefTuple[int] = (), *, """ def conn_factory() -> SSHServerConnection: - """Return an SSH client connection factory""" + """Return an SSH server connection factory""" return SSHServerConnection(loop, new_options, acceptor, error_handler) loop = asyncio.get_event_loop() - new_options = cast(SSHServerConnectionOptions, await _run_in_executor( - loop, SSHServerConnectionOptions, options, config=config, - host=host, port=port, tunnel=tunnel, family=family, **kwargs)) + new_options = await SSHServerConnectionOptions.construct( + options, config=config, host=host, port=port, tunnel=tunnel, + family=family, **kwargs) # pylint: disable=attribute-defined-outside-init new_options.proxy_command = None return await asyncio.wait_for( - _listen(new_options, loop, flags, backlog, reuse_address, + _listen(new_options, config, loop, flags, backlog, sock, reuse_address, reuse_port, conn_factory, 'Creating SSH listener on'), timeout=new_options.connect_timeout) @async_context_manager -async def listen_reverse(host: str = '', port: DefTuple[int] = (), *, +async def listen_reverse(host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelListener] = (), family: DefTuple[int] = (), - flags:int = socket.AI_PASSIVE, backlog: int = 100, + flags: int = socket.AI_PASSIVE, backlog: int = 100, + sock: Optional[socket.socket] = None, reuse_address: bool = False, reuse_port: bool = False, acceptor: _AcceptHandler = None, error_handler: _ErrorHandler = None, @@ -7932,13 +9451,25 @@ async def listen_reverse(host: str = '', port: DefTuple[int] = (), *, The port number to listen on. If not specified, the default SSH port is used. :param tunnel: (optional) - An existing SSH client connection that this new listener should - be forwarded over. If set, a remote TCP/IP listener will be - opened on this connection on the requested host and port rather - than listening directly via TCP. A string of the form + An existing SSH client connection that this new connection should + be tunneled over. If set, a direct TCP/IP tunnel will be opened + over this connection to the requested host and port rather than + connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param family: (optional) The address family to use when creating the server. By default, the address families are automatically selected based on the host. @@ -7946,6 +9477,9 @@ async def listen_reverse(host: str = '', port: DefTuple[int] = (), *, The flags to pass to getaddrinfo() when looking up the host :param backlog: (optional) The maximum number of queued connections allowed on listeners + :param sock: (optional) + A pre-existing socket to use instead of creating and binding + a new socket. When this is specified, host and port should not :param reuse_address: (optional) Whether or not to reuse a local socket in the TIME_WAIT state without waiting for its natural timeout to expire. If not @@ -7969,7 +9503,7 @@ async def listen_reverse(host: str = '', port: DefTuple[int] = (), *, :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the @@ -7989,8 +9523,11 @@ async def listen_reverse(host: str = '', port: DefTuple[int] = (), *, :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type backlog: `int` + :type sock: :class:`socket.socket` or `None` :type reuse_address: `bool` :type reuse_port: `bool` + :type acceptor: `callable` or coroutine + :type error_handler: `callable` :type config: `list` of `str` :type options: :class:`SSHClientConnectionOptions` @@ -8005,22 +9542,22 @@ def conn_factory() -> SSHClientConnection: loop = asyncio.get_event_loop() - new_options = cast(SSHClientConnectionOptions, await _run_in_executor( - loop, SSHClientConnectionOptions, options, config=config, - host=host, port=port, tunnel=tunnel, family=family, **kwargs)) + new_options = await SSHClientConnectionOptions.construct( + options, config=config, host=host, port=port, tunnel=tunnel, + family=family, **kwargs) # pylint: disable=attribute-defined-outside-init new_options.proxy_command = None return await asyncio.wait_for( - _listen(new_options, loop, flags, backlog, + _listen(new_options, config, loop, flags, backlog, sock, reuse_address, reuse_port, conn_factory, 'Creating reverse direction SSH listener on'), timeout=new_options.connect_timeout) async def create_connection(client_factory: _ClientFactory, - host: str, port: DefTuple[int] = (), + host = '', port: DefTuple[int] = (), **kwargs: object) -> \ Tuple[SSHClientConnection, SSHClient]: """Create an SSH client connection @@ -8046,7 +9583,7 @@ async def create_connection(client_factory: _ClientFactory, @async_context_manager async def create_server(server_factory: _ServerFactory, - host: str = '', port: DefTuple[int] = (), + host = '', port: DefTuple[int] = (), **kwargs: object) -> SSHAcceptor: """Create an SSH server @@ -8064,10 +9601,11 @@ async def create_server(server_factory: _ServerFactory, async def get_server_host_key( - host: str, port: DefTuple[int] = (), *, + host = '', port: DefTuple[int] = (), *, tunnel: DefTuple[_TunnelConnector] = (), - proxy_command: DefTuple[str] = (), family: DefTuple[int] = (), + proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), + sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), kex_algs: _AlgsArg = (), server_host_key_algs: _AlgsArg = (), config: DefTuple[ConfigPaths] = (), @@ -8089,7 +9627,7 @@ async def get_server_host_key( method may return `None` even when the handshake completes. - :param host: + :param host: (optional) The hostname or address to connect to :param port: (optional) The port number to connect to. If not specified, the default @@ -8100,8 +9638,20 @@ async def get_server_host_key( over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run to make a connection to the SSH server. Data will be @@ -8115,6 +9665,10 @@ async def get_server_host_key( The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting + :param sock: (optional) + An existing already-connected socket to run SSH over, instead of + opening up a new connection. When this is specified, none of + host, port family, flags, or local_addr should be specified. :param client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to `'AsyncSSH'` and its version number. @@ -8128,7 +9682,7 @@ async def get_server_host_key( :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the @@ -8150,6 +9704,7 @@ async def get_server_host_key( :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` + :type sock: :class:`socket.socket` or `None` :type client_version: `str` :type kex_algs: `str` or `list` of `str` :type server_host_key_algs: `str` or `list` of `str` @@ -8167,16 +9722,16 @@ def conn_factory() -> SSHClientConnection: loop = asyncio.get_event_loop() - new_options = cast(SSHClientConnectionOptions, await _run_in_executor( - loop, SSHClientConnectionOptions, options, config=config, - host=host, port=port, tunnel=tunnel, proxy_command=proxy_command, - family=family, local_addr=local_addr, known_hosts=None, - server_host_key_algs=server_host_key_algs, x509_trusted_certs=None, - x509_trusted_cert_paths=None, x509_purposes='any', gss_host=None, - kex_algs=kex_algs, client_version=client_version)) + new_options = await SSHClientConnectionOptions.construct( + options, config=config, host=host, port=port, tunnel=tunnel, + proxy_command=proxy_command, family=family, local_addr=local_addr, + known_hosts=None, server_host_key_algs=server_host_key_algs, + x509_trusted_certs=None, x509_trusted_cert_paths=None, + x509_purposes='any', gss_host=None, kex_algs=kex_algs, + client_version=client_version) conn = await asyncio.wait_for( - _connect(new_options, loop, flags, conn_factory, + _connect(new_options, config, loop, flags, sock, conn_factory, 'Fetching server host key from'), timeout=new_options.connect_timeout) @@ -8190,10 +9745,11 @@ def conn_factory() -> SSHClientConnection: async def get_server_auth_methods( - host: str, port: DefTuple[int] = (), username: DefTuple[str] = (), *, + host = '', port: DefTuple[int] = (), username: DefTuple[str] = (), *, tunnel: DefTuple[_TunnelConnector] = (), - proxy_command: DefTuple[str] = (), family: DefTuple[int] = (), + proxy_command: DefTuple[_ProxyCommand] = (), family: DefTuple[int] = (), flags: int = 0, local_addr: DefTuple[HostPort] = (), + sock: Optional[socket.socket] = None, client_version: DefTuple[BytesOrStr] = (), kex_algs: _AlgsArg = (), server_host_key_algs: _AlgsArg = (), config: DefTuple[ConfigPaths] = (), @@ -8211,7 +9767,7 @@ async def get_server_auth_methods( want to specify the specific user you would like to get auth methods for. - :param host: + :param host: (optional) The hostname or address to connect to :param port: (optional) The port number to connect to. If not specified, the default @@ -8225,8 +9781,20 @@ async def get_server_auth_methods( over this connection to the requested host and port rather than connecting directly via TCP. A string of the form [user@]host[:port] may also be specified, in which case a - connection will first be made to that host and it will then be - used as a tunnel. + connection will be made to that host and then used as a tunnel. + A comma-separated list may also be specified to establish a + tunnel through multiple hosts. + + .. note:: When specifying tunnel as a string, any config + options in the call will apply only when opening + a connection to the final destination host and + port. However, settings to use when opening + tunnels may be specified via a configuration file. + To get more control of config options used to + open the tunnel, :func:`connect` can be called + explicitly, and the resulting client connection + can be passed as the tunnel argument. + :param proxy_command: (optional) A string or list of strings specifying a command and arguments to run to make a connection to the SSH server. Data will be @@ -8240,6 +9808,10 @@ async def get_server_auth_methods( The flags to pass to getaddrinfo() when looking up the host address :param local_addr: (optional) The host and port to bind the socket to before connecting + :param sock: (optional) + An existing already-connected socket to run SSH over, instead of + opening up a new connection. When this is specified, none of + host, port family, flags, or local_addr should be specified. :param client_version: (optional) An ASCII string to advertise to the SSH server as the version of this client, defaulting to `'AsyncSSH'` and its version number. @@ -8253,7 +9825,7 @@ async def get_server_auth_methods( :param config: (optional) Paths to OpenSSH client configuration files to load. This configuration will be used as a fallback to override the - defaults for settings which are not explcitly specified using + defaults for settings which are not explicitly specified using AsyncSSH's configuration options. If no paths are specified and no config paths were set when constructing the `options` argument (if any), an attempt will be made to load the @@ -8275,6 +9847,7 @@ async def get_server_auth_methods( :type family: `socket.AF_UNSPEC`, `socket.AF_INET`, or `socket.AF_INET6` :type flags: flags to pass to :meth:`getaddrinfo() ` :type local_addr: tuple of `str` and `int` + :type sock: :class:`socket.socket` or `None` :type client_version: `str` :type kex_algs: `str` or `list` of `str` :type server_host_key_algs: `str` or `list` of `str` @@ -8292,17 +9865,17 @@ def conn_factory() -> SSHClientConnection: loop = asyncio.get_event_loop() - new_options = cast(SSHClientConnectionOptions, await _run_in_executor( - loop, SSHClientConnectionOptions, options, config=config, - host=host, port=port, username=username, tunnel=tunnel, - proxy_command=proxy_command, family=family, local_addr=local_addr, - known_hosts=None, server_host_key_algs=server_host_key_algs, + new_options = await SSHClientConnectionOptions.construct( + options, config=config, host=host, port=port, username=username, + tunnel=tunnel, proxy_command=proxy_command, family=family, + local_addr=local_addr, known_hosts=None, + server_host_key_algs=server_host_key_algs, x509_trusted_certs=None, x509_trusted_cert_paths=None, x509_purposes='any', gss_host=None, kex_algs=kex_algs, - client_version=client_version)) + client_version=client_version) conn = await asyncio.wait_for( - _connect(new_options, loop, flags, conn_factory, + _connect(new_options, config, loop, flags, sock, conn_factory, 'Fetching server auth methods from'), timeout=new_options.connect_timeout) diff --git a/asyncssh/crypto/__init__.py b/asyncssh/crypto/__init__.py index a7ad2a8..21ba80a 100644 --- a/asyncssh/crypto/__init__.py +++ b/asyncssh/crypto/__init__.py @@ -24,6 +24,8 @@ from .dsa import DSAPrivateKey, DSAPublicKey +from .dh import DH + from .ec import ECDSAPrivateKey, ECDSAPublicKey, ECDH from .ed import ed25519_available, ed448_available @@ -38,6 +40,8 @@ from .rsa import RSAPrivateKey, RSAPublicKey +from .pq import mlkem_available, sntrup_available, PQDH + # Import chacha20-poly1305 cipher if available from .chacha import ChachaCipher, chacha_available @@ -51,5 +55,18 @@ try: from .x509 import X509Certificate, X509Name, X509NamePattern from .x509 import generate_x509_certificate, import_x509_certificate -except ImportError: # pragma: no cover +except (ImportError, AttributeError): # pragma: no cover pass + +__all__ = [ + 'BasicCipher', 'ChachaCipher', 'CryptoKey', 'Curve25519DH', 'Curve448DH', + 'DH', 'DSAPrivateKey', 'DSAPublicKey', 'ECDH', 'ECDSAPrivateKey', + 'ECDSAPublicKey', 'EdDSAPrivateKey', 'EdDSAPublicKey', 'GCMCipher', 'PQDH', + 'PyCAKey', 'RSAPrivateKey', 'RSAPublicKey', 'chacha_available', + 'curve25519_available', 'curve448_available', 'X509Certificate', + 'X509Name', 'X509NamePattern', 'ed25519_available', 'ed448_available', + 'generate_x509_certificate', 'get_cipher_params', 'import_x509_certificate', + 'lookup_ec_curve_by_params', 'mlkem_available', 'pbkdf2_hmac', + 'register_cipher', 'sntrup_available', 'umac32', 'umac64', 'umac96', + 'umac128' +] diff --git a/asyncssh/crypto/cipher.py b/asyncssh/crypto/cipher.py index 1c70d52..cdbba6a 100644 --- a/asyncssh/crypto/cipher.py +++ b/asyncssh/crypto/cipher.py @@ -1,4 +1,4 @@ -# Copyright (c) 2014-2021 by Ron Frederick and others. +# Copyright (c) 2014-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,17 +20,24 @@ """A shim around PyCA for accessing symmetric ciphers needed by AsyncSSH""" +from types import ModuleType from typing import Any, MutableMapping, Optional, Tuple +import warnings from cryptography.exceptions import InvalidTag from cryptography.hazmat.primitives.ciphers import Cipher, CipherContext from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from cryptography.hazmat.primitives.ciphers.algorithms import AES, ARC4 -from cryptography.hazmat.primitives.ciphers.algorithms import Blowfish, CAST5 -from cryptography.hazmat.primitives.ciphers.algorithms import SEED, TripleDES - from cryptography.hazmat.primitives.ciphers.modes import CBC, CTR +import cryptography.hazmat.primitives.ciphers.algorithms as _algs + +_decrepit_algs: Optional[ModuleType] + +try: + import cryptography.hazmat.decrepit.ciphers.algorithms as _decrepit_algs +except ImportError: # pragma: no cover + _decrepit_algs = None + _CipherAlgs = Tuple[Any, Any, int] _CipherParams = Tuple[int, int, int] @@ -134,27 +141,44 @@ def get_cipher_params(cipher_name: str) -> _CipherParams: _cipher_alg_list = ( - ('aes128-cbc', AES, CBC, 0, 16, 16, 16), - ('aes192-cbc', AES, CBC, 0, 24, 16, 16), - ('aes256-cbc', AES, CBC, 0, 32, 16, 16), - ('aes128-ctr', AES, CTR, 0, 16, 16, 16), - ('aes192-ctr', AES, CTR, 0, 24, 16, 16), - ('aes256-ctr', AES, CTR, 0, 32, 16, 16), - ('aes128-gcm', None, None, 0, 16, 12, 16), - ('aes256-gcm', None, None, 0, 32, 12, 16), - ('arcfour', ARC4, None, 0, 16, 1, 1), - ('arcfour40', ARC4, None, 0, 5, 1, 1), - ('arcfour128', ARC4, None, 1536, 16, 1, 1), - ('arcfour256', ARC4, None, 1536, 32, 1, 1), - ('blowfish-cbc', Blowfish, CBC, 0, 16, 8, 8), - ('cast128-cbc', CAST5, CBC, 0, 16, 8, 8), - ('des-cbc', TripleDES, CBC, 0, 8, 8, 8), - ('des2-cbc', TripleDES, CBC, 0, 16, 8, 8), - ('des3-cbc', TripleDES, CBC, 0, 24, 8, 8), - ('seed-cbc', SEED, CBC, 0, 16, 16, 16) + ('aes128-cbc', 'AES', CBC, 0, 16, 16, 16), + ('aes192-cbc', 'AES', CBC, 0, 24, 16, 16), + ('aes256-cbc', 'AES', CBC, 0, 32, 16, 16), + ('aes128-ctr', 'AES', CTR, 0, 16, 16, 16), + ('aes192-ctr', 'AES', CTR, 0, 24, 16, 16), + ('aes256-ctr', 'AES', CTR, 0, 32, 16, 16), + ('aes128-gcm', None, None, 0, 16, 12, 16), + ('aes256-gcm', None, None, 0, 32, 12, 16), + ('arcfour', 'ARC4', None, 0, 16, 1, 1), + ('arcfour40', 'ARC4', None, 0, 5, 1, 1), + ('arcfour128', 'ARC4', None, 1536, 16, 1, 1), + ('arcfour256', 'ARC4', None, 1536, 32, 1, 1), + ('blowfish-cbc', 'Blowfish', CBC, 0, 16, 8, 8), + ('cast128-cbc', 'CAST5', CBC, 0, 16, 8, 8), + ('des-cbc', 'TripleDES', CBC, 0, 8, 8, 8), + ('des2-cbc', 'TripleDES', CBC, 0, 16, 8, 8), + ('des3-cbc', 'TripleDES', CBC, 0, 24, 8, 8), + ('seed-cbc', 'SEED', CBC, 0, 16, 16, 16) ) -for _cipher_name, _cipher, _mode, _initial_bytes, \ - _key_size, _iv_size, _block_size in _cipher_alg_list: - _cipher_algs[_cipher_name] = (_cipher, _mode, _initial_bytes) - register_cipher(_cipher_name, _key_size, _iv_size, _block_size) +with warnings.catch_warnings(): + warnings.simplefilter('ignore') + + for _cipher_name, _alg, _mode, _initial_bytes, \ + _key_size, _iv_size, _block_size in _cipher_alg_list: + if _alg: + try: + _cipher = getattr(_algs, _alg) + except AttributeError as exc: # pragma: no cover + if _decrepit_algs: + try: + _cipher = getattr(_decrepit_algs, _alg) + except AttributeError: + raise exc from None + else: + raise + else: + _cipher = None + + _cipher_algs[_cipher_name] = (_cipher, _mode, _initial_bytes) + register_cipher(_cipher_name, _key_size, _iv_size, _block_size) diff --git a/asyncssh/crypto/dh.py b/asyncssh/crypto/dh.py new file mode 100644 index 0000000..52e09e7 --- /dev/null +++ b/asyncssh/crypto/dh.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022 by Ron Frederick and others. +# +# This program and the accompanying materials are made available under +# the terms of the Eclipse Public License v2.0 which accompanies this +# distribution and is available at: +# +# http://www.eclipse.org/legal/epl-2.0/ +# +# This program may also be made available under the following secondary +# licenses when the conditions for such availability set forth in the +# Eclipse Public License v2.0 are satisfied: +# +# GNU General Public License, Version 2.0, or any later versions of +# that license +# +# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later +# +# Contributors: +# Ron Frederick - initial implementation, API, and documentation + +"""A shim around PyCA for Diffie Hellman key exchange""" + +from cryptography.hazmat.primitives.asymmetric import dh + + +class DH: + """A shim around PyCA for Diffie Hellman key exchange""" + + def __init__(self, g: int, p: int): + self._pn = dh.DHParameterNumbers(p, g) + self._priv_key = self._pn.parameters().generate_private_key() + + def get_public(self) -> int: + """Return the public key to send in the handshake""" + + pub_key = self._priv_key.public_key() + + return pub_key.public_numbers().y + + def get_shared(self, peer_public: int) -> int: + """Return the shared key from the peer's public key""" + + peer_key = dh.DHPublicNumbers(peer_public, self._pn).public_key() + shared_key = self._priv_key.exchange(peer_key) + + return int.from_bytes(shared_key, 'big') diff --git a/asyncssh/crypto/dsa.py b/asyncssh/crypto/dsa.py index 8f3bd32..befb544 100644 --- a/asyncssh/crypto/dsa.py +++ b/asyncssh/crypto/dsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2014-2021 by Ron Frederick and others. +# Copyright (c) 2014-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -36,7 +36,8 @@ class _DSAKey(CryptoKey): """Base class for shim around PyCA for DSA keys""" def __init__(self, pyca_key: PyCAKey, params: dsa.DSAParameterNumbers, - pub: dsa.DSAPublicNumbers, priv: dsa.DSAPrivateNumbers = None): + pub: dsa.DSAPublicNumbers, + priv: Optional[dsa.DSAPrivateNumbers] = None): super().__init__(pyca_key) self._params = params diff --git a/asyncssh/crypto/ec.py b/asyncssh/crypto/ec.py index 8928378..31a9330 100644 --- a/asyncssh/crypto/ec.py +++ b/asyncssh/crypto/ec.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -61,8 +61,7 @@ def lookup_curve(cls, curve_id: bytes) -> Type[ec.EllipticCurve]: try: return _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered - raise ValueError('Unknown EC curve %s' % - curve_id.decode()) from None + raise ValueError(f'Unknown EC curve {curve_id.decode()}') from None @property def curve_id(self) -> bytes: @@ -181,8 +180,7 @@ def __init__(self, curve_id: bytes): try: curve = _curves[curve_id] except KeyError: # pragma: no cover, other curves not registered - raise ValueError('Unknown EC curve %s' % - curve_id.decode()) from None + raise ValueError(f'Unknown EC curve {curve_id.decode()}') from None self._priv_key = ec.generate_private_key(curve()) @@ -194,12 +192,15 @@ def get_public(self) -> bytes: return pub_key.public_bytes(Encoding.X962, PublicFormat.UncompressedPoint) - def get_shared(self, peer_public: bytes) -> int: - """Return the shared key from the peer's public key""" + def get_shared_bytes(self, peer_public: bytes) -> bytes: + """Return the shared key from the peer's public key as bytes""" peer_key = ec.EllipticCurvePublicKey.from_encoded_point( self._priv_key.curve, peer_public) - shared_key = self._priv_key.exchange(ec.ECDH(), peer_key) + return self._priv_key.exchange(ec.ECDH(), peer_key) + + def get_shared(self, peer_public: bytes) -> int: + """Return the shared key from the peer's public key""" - return int.from_bytes(shared_key, 'big') + return int.from_bytes(self.get_shared_bytes(peer_public), 'big') diff --git a/asyncssh/crypto/ed.py b/asyncssh/crypto/ed.py index e239beb..d9268cf 100644 --- a/asyncssh/crypto/ed.py +++ b/asyncssh/crypto/ed.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021 by Ron Frederick and others. +# Copyright (c) 2019-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -50,7 +50,8 @@ class _EdDSAKey(CryptoKey): """Base class for shim around PyCA for EdDSA keys""" - def __init__(self, pyca_key: PyCAKey, pub: bytes, priv: bytes = None): + def __init__(self, pyca_key: PyCAKey, pub: bytes, + priv: Optional[bytes] = None): super().__init__(pyca_key) self._pub = pub @@ -146,7 +147,7 @@ def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: class _EdDSANaclKey: """Base class for shim around libnacl for EdDSA keys""" - def __init__(self, pub: bytes, priv: bytes = None): + def __init__(self, pub: bytes, priv: Optional[bytes] = None): self._pub = pub self._priv = priv @@ -243,12 +244,16 @@ def get_public(self) -> bytes: return self._priv_key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw) + def get_shared_bytes(self, peer_public: bytes) -> bytes: + """Return the shared key from the peer's public key as bytes""" + + peer_key = x25519.X25519PublicKey.from_public_bytes(peer_public) + return self._priv_key.exchange(peer_key) + def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" - peer_key = x25519.X25519PublicKey.from_public_bytes(peer_public) - shared = self._priv_key.exchange(peer_key) - return int.from_bytes(shared, 'big') + return int.from_bytes(self.get_shared_bytes(peer_public), 'big') else: # pragma: no cover class Curve25519DH: # type: ignore """Curve25519 Diffie Hellman implementation based on libnacl""" @@ -267,8 +272,8 @@ def get_public(self) -> bytes: return public.raw - def get_shared(self, peer_public: bytes) -> int: - """Return the shared key from the peer's public key""" + def get_shared_bytes(self, peer_public: bytes) -> bytes: + """Return the shared key from the peer's public key as bytes""" if len(peer_public) != _CURVE25519_BYTES: raise ValueError('Invalid curve25519 public key size') @@ -278,7 +283,12 @@ def get_shared(self, peer_public: bytes) -> int: if _curve25519(shared, self._private, peer_public) != 0: raise ValueError('Curve25519 failed') - return int.from_bytes(shared.raw, 'big') + return shared.raw + + def get_shared(self, peer_public: bytes) -> int: + """Return the shared key from the peer's public key""" + + return int.from_bytes(self.get_shared_bytes(peer_public), 'big') try: from libnacl import nacl diff --git a/asyncssh/crypto/misc.py b/asyncssh/crypto/misc.py index 1cc5455..4fc2ce2 100644 --- a/asyncssh/crypto/misc.py +++ b/asyncssh/crypto/misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -60,5 +60,11 @@ def pyca_key(self) -> PyCAKey: def sign(self, data: bytes, hash_name: str = '') -> bytes: """Sign a block of data""" + # pylint: disable=no-self-use + raise RuntimeError # pragma: no cover + def verify(self, data: bytes, sig: bytes, hash_name: str = '') -> bool: """Verify the signature on a block of data""" + + # pylint: disable=no-self-use + raise RuntimeError # pragma: no cover diff --git a/asyncssh/crypto/pq.py b/asyncssh/crypto/pq.py new file mode 100644 index 0000000..bf24503 --- /dev/null +++ b/asyncssh/crypto/pq.py @@ -0,0 +1,103 @@ +# Copyright (c) 2022-2024 by Ron Frederick and others. +# +# This program and the accompanying materials are made available under +# the terms of the Eclipse Public License v2.0 which accompanies this +# distribution and is available at: +# +# http://www.eclipse.org/legal/epl-2.0/ +# +# This program may also be made available under the following secondary +# licenses when the conditions for such availability set forth in the +# Eclipse Public License v2.0 are satisfied: +# +# GNU General Public License, Version 2.0, or any later versions of +# that license +# +# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later +# +# Contributors: +# Ron Frederick - initial implementation, API, and documentation + +"""A shim around liboqs for Streamlined NTRU Prime post-quantum encryption""" + +import ctypes +import ctypes.util +from typing import Mapping, Tuple + + +_pq_algs: Mapping[bytes, Tuple[int, int, int, int, str]] = { + b'mlkem768': (1184, 2400, 1088, 32, 'KEM_ml_kem_768'), + b'mlkem1024': (1568, 3168, 1568, 32, 'KEM_ml_kem_1024'), + b'sntrup761': (1158, 1763, 1039, 32, 'KEM_ntruprime_sntrup761') +} + +mlkem_available = False +sntrup_available = False + +for lib in ('oqs', 'liboqs'): + _oqs_lib = ctypes.util.find_library(lib) + + if _oqs_lib: # pragma: no branch + break +else: # pragma: no cover + _oqs_lib = None + +if _oqs_lib: # pragma: no branch + _oqs = ctypes.cdll.LoadLibrary(_oqs_lib) + + mlkem_available = (hasattr(_oqs, 'OQS_KEM_ml_kem_768_keypair') or + hasattr(_oqs, 'OQS_KEM_ml_kem_768_ipd_keypair')) + sntrup_available = hasattr(_oqs, 'OQS_KEM_ntruprime_sntrup761_keypair') + + +class PQDH: + """A shim around liboqs for post-quantum key exchange algorithms""" + + def __init__(self, alg_name: bytes): + try: + self.pubkey_bytes, self.privkey_bytes, \ + self.ciphertext_bytes, self.secret_bytes, \ + oqs_name = _pq_algs[alg_name] + except KeyError: # pragma: no cover, other algs not registered + raise ValueError(f'Unknown PQ algorithm {oqs_name}') from None + + if not hasattr(_oqs, 'OQS_' + oqs_name + '_keypair'): # pragma: no cover + oqs_name += '_ipd' + + self._keypair = getattr(_oqs, 'OQS_' + oqs_name + '_keypair') + self._encaps = getattr(_oqs, 'OQS_' + oqs_name + '_encaps') + self._decaps = getattr(_oqs, 'OQS_' + oqs_name + '_decaps') + + def keypair(self) -> Tuple[bytes, bytes]: + """Make a new key pair""" + + pubkey = ctypes.create_string_buffer(self.pubkey_bytes) + privkey = ctypes.create_string_buffer(self.privkey_bytes) + self._keypair(pubkey, privkey) + + return pubkey.raw, privkey.raw + + def encaps(self, pubkey: bytes) -> Tuple[bytes, bytes]: + """Generate a random secret and encrypt it with a public key""" + + if len(pubkey) != self.pubkey_bytes: + raise ValueError('Invalid public key') + + ciphertext = ctypes.create_string_buffer(self.ciphertext_bytes) + secret = ctypes.create_string_buffer(self.secret_bytes) + + self._encaps(ciphertext, secret, pubkey) + + return secret.raw, ciphertext.raw + + def decaps(self, ciphertext: bytes, privkey: bytes) -> bytes: + """Decrypt an encrypted secret using a private key""" + + if len(ciphertext) != self.ciphertext_bytes: + raise ValueError('Invalid ciphertext') + + secret = ctypes.create_string_buffer(self.secret_bytes) + + self._decaps(secret, ciphertext, privkey) + + return secret.raw diff --git a/asyncssh/crypto/rsa.py b/asyncssh/crypto/rsa.py index 989666a..66a2774 100644 --- a/asyncssh/crypto/rsa.py +++ b/asyncssh/crypto/rsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2014-2021 by Ron Frederick and others. +# Copyright (c) 2014-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -38,7 +38,7 @@ class _RSAKey(CryptoKey): """Base class for shim around PyCA for RSA keys""" def __init__(self, pyca_key: PyCAKey, pub: rsa.RSAPublicNumbers, - priv: rsa.RSAPrivateNumbers = None): + priv: Optional[rsa.RSAPrivateNumbers] = None): super().__init__(pyca_key) self._pub = pub @@ -98,12 +98,14 @@ class RSAPrivateKey(_RSAKey): @classmethod def construct(cls, n: int, e: int, d: int, p: int, q: int, - dmp1: int, dmq1: int, iqmp: int) -> 'RSAPrivateKey': + dmp1: int, dmq1: int, iqmp: int, + skip_validation: bool) -> 'RSAPrivateKey': """Construct an RSA private key""" pub = rsa.RSAPublicNumbers(e, n) priv = rsa.RSAPrivateNumbers(p, q, d, dmp1, dmq1, iqmp, pub) - priv_key = priv.private_key() + priv_key = priv.private_key( + unsafe_skip_rsa_key_validation=skip_validation) return cls(priv_key, pub, priv) diff --git a/asyncssh/crypto/umac.py b/asyncssh/crypto/umac.py index 2e44f35..981be41 100644 --- a/asyncssh/crypto/umac.py +++ b/asyncssh/crypto/umac.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -23,7 +23,6 @@ import binascii import ctypes import ctypes.util -import sys from typing import TYPE_CHECKING, Callable, Optional @@ -43,8 +42,8 @@ def _build_umac(size: int) -> '_New': """Function to build UMAC wrapper for a specific digest size""" - _name = 'umac%d' % size - _prefix = 'nettle_%s_' % _name + _name = f'umac{size}' + _prefix = f'nettle_{_name}_' try: _context_size: int = getattr(_nettle, _prefix + '_ctx_size')() @@ -126,8 +125,13 @@ def hexdigest(self) -> str: return _UMAC.new -_nettle_lib = 'libnettle-6' if sys.platform == 'win32' \ - else ctypes.util.find_library('nettle') +for lib in ('nettle', 'libnettle', 'libnettle-6'): + _nettle_lib = ctypes.util.find_library(lib) + + if _nettle_lib: # pragma: no branch + break +else: # pragma: no cover + _nettle_lib = None if _nettle_lib: # pragma: no branch _nettle = ctypes.cdll.LoadLibrary(_nettle_lib) diff --git a/asyncssh/crypto/x509.py b/asyncssh/crypto/x509.py index 560b993..14f91ae 100644 --- a/asyncssh/crypto/x509.py +++ b/asyncssh/crypto/x509.py @@ -55,11 +55,9 @@ _nscomment_oid = x509.ObjectIdentifier('2.16.840.1.113730.1.13') -_datetime_min = datetime.utcfromtimestamp(0).replace(microsecond=1, - tzinfo=timezone.utc) +_datetime_min = datetime.fromtimestamp(0, timezone.utc).replace(microsecond=1) -_datetime_32bit_max = datetime.utcfromtimestamp(2**31 - 1).replace( - tzinfo=timezone.utc) +_datetime_32bit_max = datetime.fromtimestamp(2**31 - 1, timezone.utc) if sys.platform == 'win32': # pragma: no cover # Windows' datetime.max is year 9999, but timestamps that large don't work @@ -75,12 +73,13 @@ def _to_generalized_time(t: int) -> datetime: return _datetime_min else: try: - return datetime.utcfromtimestamp(t).replace(tzinfo=timezone.utc) + return datetime.fromtimestamp(t, timezone.utc) except (OSError, OverflowError): try: # Work around a bug in cryptography which shows up on # systems with a small time_t. - datetime.utcfromtimestamp(_datetime_max.timestamp() - 1) + datetime.fromtimestamp(_datetime_max.timestamp() - 1, + timezone.utc) return _datetime_max except (OSError, OverflowError): # pragma: no cover return _datetime_32bit_max @@ -95,8 +94,8 @@ def _to_purpose_oids(purposes: _Purposes) -> _PurposeOIDs: if not purposes or 'any' in purposes or _purpose_any in purposes: purpose_oids = None else: - purpose_oids = set(_purpose_to_oid.get(p) or x509.ObjectIdentifier(p) - for p in purposes) + purpose_oids = {_purpose_to_oid.get(p) or x509.ObjectIdentifier(p) + for p in purposes} return purpose_oids @@ -144,8 +143,8 @@ class X509Name(x509.Name): ('CN', x509.NameOID.COMMON_NAME), ('DC', x509.NameOID.DOMAIN_COMPONENT)) - _to_oid = dict((k, v) for k, v in _attrs) - _from_oid = dict((v, k) for k, v in _attrs) + _to_oid = dict(_attrs) + _from_oid = {v: k for k, v in _attrs} def __init__(self, name: _NameInit): if isinstance(name, str): @@ -169,7 +168,7 @@ def _format_attr(self, nameattr: x509.NameAttribute) -> str: """Format an X.509 NameAttribute as a string""" attr = self._from_oid.get(nameattr.oid) or nameattr.oid.dotted_string - return attr + '=' + self._escape.sub(r'\\\1', nameattr.value) + return attr + '=' + self._escape.sub(r'\\\1', cast(str, nameattr.value)) def _parse_name(self, name: str) -> \ Iterable[x509.RelativeDistinguishedName]: @@ -261,7 +260,7 @@ def __init__(self, cert: x509.Certificate, data: bytes): [str(ip) for ip in sans.get_values_for_type(x509.IPAddress)] except x509.ExtensionNotFound: cn = cert.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME) - principals = [attr.value for attr in cn] + principals = [cast(str, attr.value) for attr in cn] self.user_principals = principals self.host_principals = principals @@ -309,7 +308,7 @@ def validate(self, trust_store: Sequence['X509Certificate'], None) x509_ctx.verify_certificate() except crypto.X509StoreContextError as exc: - raise ValueError(str(exc)) from None + raise ValueError(f'X.509 chain validation error: {exc}') from None def generate_x509_certificate(signing_key: PyCAKey, key: PyCAKey, @@ -405,7 +404,8 @@ def generate_x509_certificate(signing_key: PyCAKey, key: PyCAKey, except KeyError: raise ValueError('Unknown hash algorithm') from None - cert = builder.sign(cast(PyCAPrivateKey, signing_key), hash_alg) + cert = builder.sign(cast(PyCAPrivateKey, signing_key), + hash_alg) # type: ignore data = cert.public_bytes(Encoding.DER) return X509Certificate(cert, data) diff --git a/asyncssh/dsa.py b/asyncssh/dsa.py index ee9de4a..e312aa8 100644 --- a/asyncssh/dsa.py +++ b/asyncssh/dsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -41,7 +41,7 @@ class _DSAKey(SSHKey): _key: Union[DSAPrivateKey, DSAPublicKey] algorithm = b'ssh-dss' - default_hash_name = 'sha1' + default_x509_hash = 'sha256' pem_name = b'DSA' pkcs8_oid = ObjectIdentifier('1.2.840.10040.4.1') sig_algorithms = (algorithm,) @@ -123,7 +123,7 @@ def decode_pkcs8_private(cls, alg_params: object, if (isinstance(alg_params, tuple) and len(alg_params) == 3 and all_ints(alg_params) and isinstance(x, int)): p, q, g = alg_params - y = pow(g, x, p) + y: int = pow(g, x, p) return p, q, g, y, x else: return None diff --git a/asyncssh/ecdsa.py b/asyncssh/ecdsa.py index 25bad39..9429deb 100644 --- a/asyncssh/ecdsa.py +++ b/asyncssh/ecdsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -54,7 +54,7 @@ class _ECKey(SSHKey): _key: Union[ECDSAPrivateKey, ECDSAPublicKey] - default_hash_name = 'sha256' + default_x509_hash = 'sha256' pem_name = b'EC' pkcs8_oid = ObjectIdentifier('1.2.840.10045.2.1') @@ -91,8 +91,8 @@ def _lookup_curve(cls, alg_params: object) -> bytes: try: curve_id = _alg_oid_map[alg_params] except KeyError: - raise KeyImportError('Unknown elliptic curve OID %s' % - alg_params) from None + raise KeyImportError('Unknown elliptic curve OID ' + f'{alg_params}') from None elif (isinstance(alg_params, tuple) and len(alg_params) >= 5 and alg_params[0] == 1 and isinstance(alg_params[1], tuple) and len(alg_params[1]) == 2 and alg_params[1][0] == PRIME_FIELD and diff --git a/asyncssh/editor.py b/asyncssh/editor.py index 9b8b2f6..c202357 100644 --- a/asyncssh/editor.py +++ b/asyncssh/editor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2022 by Ron Frederick and others. +# Copyright (c) 2016-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -388,7 +388,7 @@ def _update_input(self, pos: int, column: int, new_pos: int) -> None: self._pos = new_pos - def _reset_line(self): + def _reset_line(self) -> None: """Reset input line to empty""" self._line = '' @@ -398,7 +398,7 @@ def _reset_line(self): self._start_column = self._cursor self._end_column = self._cursor - def _reset_pending(self): + def _reset_pending(self) -> None: """Reset a pending echoed line if any""" if self._line_pending: @@ -704,8 +704,10 @@ def process_input(self, data: str, datatype: DataType) -> None: self._ring_bell() self._bell_rung = False - self._chan.write(''.join(self._outbuf)) - self._outbuf.clear() + + if self._outbuf: + self._chan.write(''.join(self._outbuf)) + self._outbuf.clear() else: self._session.data_received(data, datatype) @@ -745,7 +747,7 @@ class SSHLineEditorChannel: this class is wrapped around the channel, providing the caller with the ability to enable and disable input line editing and echoing. - .. note:: Line editing is only available when a psuedo-terminal + .. note:: Line editing is only available when a pseudo-terminal is requested on the server channel and the character encoding on the channel is not set to `None`. diff --git a/asyncssh/encryption.py b/asyncssh/encryption.py index 02b0497..3e2e99b 100644 --- a/asyncssh/encryption.py +++ b/asyncssh/encryption.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -48,7 +48,7 @@ def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: - """Get paramaters of the MAC algorithm used with this encryption""" + """Get parameters of the MAC algorithm used with this encryption""" return get_mac_params(mac_alg) @@ -163,7 +163,7 @@ def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: - """Get paramaters of the MAC algorithm used with this encryption""" + """Get parameters of the MAC algorithm used with this encryption""" return 0, 16, True @@ -202,7 +202,7 @@ def new(cls, cipher_name: str, key: bytes, iv: bytes, mac_alg: bytes = b'', @classmethod def get_mac_params(cls, mac_alg: bytes) -> Tuple[int, int, bool]: - """Get paramaters of the MAC algorithm used with this encryption""" + """Get parameters of the MAC algorithm used with this encryption""" return 0, 16, True diff --git a/asyncssh/forward.py b/asyncssh/forward.py index 784db1d..db6e7cf 100644 --- a/asyncssh/forward.py +++ b/asyncssh/forward.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -22,7 +22,10 @@ import asyncio import socket -from typing import TYPE_CHECKING, Awaitable, Callable, Optional, cast +from types import TracebackType +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional +from typing import Type, cast +from typing_extensions import Self from .misc import ChannelOpenError, SockAddr @@ -38,7 +41,8 @@ class SSHForwarder(asyncio.BaseProtocol): """SSH port forwarding connection handler""" - def __init__(self, peer: Optional['SSHForwarder'] = None): + def __init__(self, peer: Optional['SSHForwarder'] = None, + extra: Optional[Dict[str, Any]] = None): self._peer = peer self._transport: Optional[asyncio.Transport] = None self._inpbuf = b'' @@ -47,6 +51,32 @@ def __init__(self, peer: Optional['SSHForwarder'] = None): if peer: peer.set_peer(self) + if extra is None: + extra = {} + + self._extra = extra + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, _exc_type: Optional[Type[BaseException]], + _exc_value: Optional[BaseException], + _traceback: Optional[TracebackType]) -> bool: + self.close() + return False + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """Get additional information about the forwarder + + This method returns extra information about the forwarder. + Currently, the only information available is the value + ``interface`` for TUN/TAP forwarders, returning the name of the + local TUN/TAP network interface created for this forwarder. + + """ + + return self._extra.get(name, default) + def set_peer(self, peer: 'SSHForwarder') -> None: """Set the peer forwarder to exchange data with""" @@ -61,7 +91,8 @@ def write(self, data: bytes) -> None: def write_eof(self) -> None: """Write end of file to the transport""" - assert self._transport is not None + if not self._transport: + return # pragma: no cover try: self._transport.write_eof() @@ -106,7 +137,7 @@ def session_started(self) -> None: """Handle session start""" def data_received(self, data: bytes, - datatype: int = None) -> None: + datatype: Optional[int] = None) -> None: """Handle incoming data from the transport""" # pylint: disable=unused-argument diff --git a/asyncssh/gss.py b/asyncssh/gss.py index 110b6fc..1a4b45e 100644 --- a/asyncssh/gss.py +++ b/asyncssh/gss.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -22,6 +22,11 @@ import sys +from typing import Optional + +from .misc import BytesOrStrDict + + try: # pylint: disable=unused-import @@ -37,8 +42,8 @@ class GSSError(ValueError): # type: ignore """Stub class for reporting that GSS is not available""" - def __init__(self, maj_code: int = 0, min_code: int = 0, - token: bytes = None): + def __init__(self, maj_code: int, min_code: int, + token: Optional[bytes] = None): super().__init__('GSS not available') self.maj_code = maj_code @@ -51,11 +56,12 @@ class GSSBase: # type: ignore class GSSClient(GSSBase): # type: ignore """Stub client class for reporting that GSS is not available""" - def __init__(self, _host: str, _delegate_creds: bool): - raise GSSError() + def __init__(self, _host: str, _store: Optional[BytesOrStrDict], + _delegate_creds: bool): + raise GSSError(0, 0) class GSSServer(GSSBase): # type: ignore """Stub client class for reporting that GSS is not available""" - def __init__(self, _host: str): - raise GSSError() + def __init__(self, _host: str, _store: Optional[BytesOrStrDict]): + raise GSSError(0, 0) diff --git a/asyncssh/gss_unix.py b/asyncssh/gss_unix.py index bc34bad..145dec2 100644 --- a/asyncssh/gss_unix.py +++ b/asyncssh/gss_unix.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,39 +20,43 @@ """GSSAPI wrapper for UNIX""" -from typing import Optional, Sequence +from typing import Optional, Sequence, SupportsBytes, cast from gssapi import Credentials, Name, NameType, OID from gssapi import RequirementFlag, SecurityContext from gssapi.exceptions import GSSError from .asn1 import OBJECT_IDENTIFIER +from .misc import BytesOrStrDict def _mech_to_oid(mech: OID) -> bytes: """Return a DER-encoded OID corresponding to the requested GSS mechanism""" - mech_bytes = bytes(mech) + mech_bytes = bytes(cast(SupportsBytes, mech)) return bytes((OBJECT_IDENTIFIER, len(mech_bytes))) + mech_bytes class GSSBase: """GSS base class""" - def __init__(self, host: str, usage: str): + def __init__(self, host: str, store: Optional[BytesOrStrDict]): if '@' in host: self._host = Name(host) else: self._host = Name('host@' + host, NameType.hostbased_service) - if usage == 'initiate': - self._creds = Credentials(usage=usage) - else: - self._creds = Credentials(name=self._host, usage=usage) + self._store = store self._mechs = [_mech_to_oid(mech) for mech in self._creds.mechs] self._ctx: Optional[SecurityContext] = None + @property + def _creds(self) -> Credentials: + """Abstract method to construct GSS credentials""" + + raise NotImplementedError + def _init_context(self) -> None: """Abstract method to construct GSS security context""" @@ -76,8 +80,8 @@ def provides_mutual_auth(self) -> bool: assert self._ctx is not None - return (RequirementFlag.mutual_authentication in - self._ctx.actual_flags) + return bool(self._ctx.actual_flags & + RequirementFlag.mutual_authentication) @property def provides_integrity(self) -> bool: @@ -85,7 +89,7 @@ def provides_integrity(self) -> bool: assert self._ctx is not None - return RequirementFlag.integrity in self._ctx.actual_flags + return bool(self._ctx.actual_flags & RequirementFlag.integrity) @property def user(self) -> str: @@ -140,17 +144,24 @@ def verify(self, data: bytes, sig: bytes) -> bool: class GSSClient(GSSBase): """GSS client""" - def __init__(self, host: str, delegate_creds: bool): - super().__init__(host, 'initiate') + def __init__(self, host: str, store: Optional[BytesOrStrDict], + delegate_creds: bool): + super().__init__(host, store) - flags = set((RequirementFlag.mutual_authentication, - RequirementFlag.integrity)) + flags = RequirementFlag.mutual_authentication | \ + RequirementFlag.integrity if delegate_creds: - flags.add(RequirementFlag.delegate_to_peer) + flags |= RequirementFlag.delegate_to_peer self._flags = flags + @property + def _creds(self) -> Credentials: + """Abstract method to construct GSS credentials""" + + return Credentials(usage='initiate', store=self._store) + def _init_context(self) -> None: """Construct GSS client security context""" @@ -161,8 +172,11 @@ def _init_context(self) -> None: class GSSServer(GSSBase): """GSS server""" - def __init__(self, host: str): - super().__init__(host, 'accept') + @property + def _creds(self) -> Credentials: + """Abstract method to construct GSS credentials""" + + return Credentials(name=self._host, usage='accept', store=self._store) def _init_context(self) -> None: """Construct GSS server security context""" diff --git a/asyncssh/gss_win32.py b/asyncssh/gss_win32.py index f9a6897..0159762 100644 --- a/asyncssh/gss_win32.py +++ b/asyncssh/gss_win32.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2022 by Ron Frederick and others. +# Copyright (c) 2017-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -35,6 +35,7 @@ from sspicon import SECPKG_ATTR_NATIVE_NAMES from .asn1 import ObjectIdentifier, der_encode +from .misc import BytesOrStrDict _krb5_oid = der_encode(ObjectIdentifier('1.2.840.113554.1.2.2')) @@ -135,7 +136,7 @@ def sign(self, data: bytes) -> bytes: try: return self._ctx.sign(data) - except SSPIError as exc: + except SSPIError as exc: # pragna: no cover raise GSSError(details=exc.strerror) from None def verify(self, data: bytes, sig: bytes) -> bool: @@ -156,7 +157,11 @@ class GSSClient(GSSBase): _mutual_auth_flag = ISC_RET_MUTUAL_AUTH _integrity_flag = ISC_RET_INTEGRITY - def __init__(self, host: str, delegate_creds: bool): + def __init__(self, host: str, store: Optional[BytesOrStrDict], + delegate_creds: bool): + if store is not None: # pragna: no cover + raise GSSError(details='GSS store not supported on Windows') + super().__init__(host) flags = ISC_REQ_MUTUAL_AUTH | ISC_REQ_INTEGRITY @@ -167,7 +172,7 @@ def __init__(self, host: str, delegate_creds: bool): try: self._ctx = ClientAuth('Kerberos', targetspn=self._host, scflags=flags) - except SSPIError as exc: + except SSPIError as exc: # pragna: no cover raise GSSError(1, 1, details=exc.strerror) from None self._init_token = self.step(None) @@ -179,7 +184,10 @@ class GSSServer(GSSBase): _mutual_auth_flag = ASC_RET_MUTUAL_AUTH _integrity_flag = ASC_RET_INTEGRITY - def __init__(self, host: str): + def __init__(self, host: str, store: Optional[BytesOrStrDict]): + if store is not None: # pragna: no cover + raise GSSError(details='GSS store not supported on Windows') + super().__init__(host) flags = ASC_REQ_MUTUAL_AUTH | ASC_REQ_INTEGRITY @@ -194,7 +202,7 @@ class GSSError(Exception): """Class for reporting GSS errors""" def __init__(self, maj_code: int = 0, min_code: int = 0, - token: bytes = None, details: str = ''): + token: Optional[bytes] = None, details: str = ''): super().__init__(details) self.maj_code = maj_code diff --git a/asyncssh/kex.py b/asyncssh/kex.py index 4560475..1458540 100644 --- a/asyncssh/kex.py +++ b/asyncssh/kex.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -26,7 +26,7 @@ from .logging import SSHLogger from .misc import HashType -from .packet import MPInt, SSHPacketHandler +from .packet import SSHPacketHandler if TYPE_CHECKING: @@ -58,7 +58,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): self._hash_alg = hash_alg - def start(self) -> None: + async def start(self) -> None: """Start key exchange""" raise NotImplementedError @@ -74,14 +74,14 @@ def logger(self) -> SSHLogger: return self._logger - def compute_key(self, k: int, h: bytes, x: bytes, + def compute_key(self, k: bytes, h: bytes, x: bytes, session_id: bytes, keylen: int) -> bytes: """Compute keys from output of key exchange""" key = b'' while len(key) < keylen: hash_obj = self._hash_alg() - hash_obj.update(MPInt(k)) + hash_obj.update(k) hash_obj.update(h) hash_obj.update(key if key else x + session_id) key += hash_obj.digest() diff --git a/asyncssh/kex_dh.py b/asyncssh/kex_dh.py index 1358a06..fbcbdb4 100644 --- a/asyncssh/kex_dh.py +++ b/asyncssh/kex_dh.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -25,12 +25,13 @@ from typing_extensions import Protocol from .constants import DEFAULT_LANG +from .crypto import Curve25519DH, Curve448DH, DH, ECDH, PQDH from .crypto import curve25519_available, curve448_available -from .crypto import Curve25519DH, Curve448DH, ECDH +from .crypto import mlkem_available, sntrup_available from .gss import GSSError from .kex import Kex, register_kex_alg, register_gss_kex_alg from .misc import HashType, KeyExchangeFailed, ProtocolError -from .misc import get_symbol_names, randrange +from .misc import get_symbol_names, run_in_executor from .packet import Boolean, MPInt, String, UInt32, SSHPacket from .public_key import SigningKey, VerifyingKey @@ -47,6 +48,9 @@ class DHKey(Protocol): def get_public(self) -> bytes: """Return the public key to send to the peer""" + def get_shared_bytes(self, peer_public: bytes) -> bytes: + """Return the shared key from the peer's public key in bytes""" + def get_shared(self, peer_public: bytes) -> int: """Return the shared key from the peer's public key""" @@ -128,10 +132,9 @@ class _KexDHBase(Kex): def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType): super().__init__(alg, conn, hash_alg) + self._dh: Optional[DH] = None self._g = 0 self._p = 0 - self._q = 0 - self._x = 0 self._e = 0 self._f = 0 self._gex_data = b'' @@ -141,9 +144,8 @@ def _init_group(self, g: int, p: int) -> None: self._g = g self._p = p - self._q = (p - 1) // 2 - def _compute_hash(self, host_key_data: bytes, k: int) -> bytes: + def _compute_hash(self, host_key_data: bytes, k: bytes) -> bytes: """Compute a hash of key information associated with the connection""" hash_obj = self._hash_alg() @@ -152,7 +154,7 @@ def _compute_hash(self, host_key_data: bytes, k: int) -> bytes: hash_obj.update(self._gex_data) hash_obj.update(self._format_client_key()) hash_obj.update(self._format_server_key()) - hash_obj.update(MPInt(k)) + hash_obj.update(k) return hash_obj.digest() def _parse_client_key(self, packet: SSHPacket) -> None: @@ -195,42 +197,33 @@ def _send_reply(self, key_data: bytes, sig: bytes) -> None: def _perform_init(self) -> None: """Compute e and send init message""" - self._x = randrange(2, self._q) - self._e = pow(self._g, self._x, self._p) + self._dh = DH(self._g, self._p) + self._e = self._dh.get_public() self._send_init() - def _compute_client_shared(self) -> int: + def _compute_client_shared(self) -> bytes: """Compute client shared key""" if not 1 <= self._f < self._p: raise ProtocolError('Kex DH f out of range') - k = pow(self._f, self._x, self._p) - - if k < 1: # pragma: no cover, shouldn't be possible with valid p - raise ProtocolError('Kex DH k out of range') - - return k + assert self._dh is not None + return MPInt(self._dh.get_shared(self._f)) - def _compute_server_shared(self) -> int: + def _compute_server_shared(self) -> bytes: """Compute server shared key""" if not 1 <= self._e < self._p: raise ProtocolError('Kex DH e out of range') - y = randrange(2, self._q) - self._f = pow(self._g, y, self._p) + self._dh = DH(self._g, self._p) + self._f = self._dh.get_public() - k = pow(self._e, y, self._p) - - if k < 1: # pragma: no cover, shouldn't be possible with valid p - raise ProtocolError('Kex DH k out of range') - - return k + return MPInt(self._dh.get_shared(self._e)) def _perform_reply(self, key: SigningKey, key_data: bytes) -> None: - """Compute f and send reply message""" + """Compute server shared key and send reply message""" k = self._compute_server_shared() h = self._compute_hash(key_data, k) @@ -282,7 +275,7 @@ def _process_reply(self, _pkttype: int, _pktid: int, host_key = client_conn.validate_server_host_key(host_key_data) self._verify_reply(host_key, host_key_data, sig) - def start(self) -> None: + async def start(self) -> None: """Start DH key exchange""" if self._conn.is_client(): @@ -350,6 +343,9 @@ def _process_request(self, pkttype: int, _pktid: int, if self._conn.is_client(): raise ProtocolError('Unexpected kex request msg') + if self._p: + raise ProtocolError('Kex DH group already requested') + self._gex_data = packet.get_remaining_payload() if pkttype == MSG_KEX_DH_GEX_REQUEST_OLD: @@ -384,6 +380,9 @@ def _process_group(self, _pkttype: int, _pktid: int, if self._conn.is_server(): raise ProtocolError('Unexpected kex group msg') + if self._p: + raise ProtocolError('Kex DH group already sent') + p = packet.get_mpint() g = packet.get_mpint() packet.check_end() @@ -392,7 +391,7 @@ def _process_group(self, _pkttype: int, _pktid: int, self._gex_data += MPInt(p) + MPInt(g) self._perform_init() - def start(self) -> None: + async def start(self) -> None: """Start DH group exchange""" if self._conn.is_client(): @@ -447,23 +446,23 @@ def _format_server_key(self) -> bytes: return String(self._server_pub) - def _compute_client_shared(self) -> int: + def _compute_client_shared(self) -> bytes: """Compute client shared key""" try: - return self._priv.get_shared(self._server_pub) + return MPInt(self._priv.get_shared(self._server_pub)) except ValueError: raise ProtocolError('Invalid ECDH server public key') from None - def _compute_server_shared(self) -> int: + def _compute_server_shared(self) -> bytes: """Compute server shared key""" try: - return self._priv.get_shared(self._client_pub) + return MPInt(self._priv.get_shared(self._client_pub)) except ValueError: raise ProtocolError('Invalid ECDH client public key') from None - def start(self) -> None: + async def start(self) -> None: """Start ECDH key exchange""" if self._conn.is_client(): @@ -475,6 +474,58 @@ def start(self) -> None: } +class _KexHybridECDH(_KexECDH): + """Handler for post-quantum key exchange""" + + def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, + pq_alg_name: bytes, ecdh_class: _ECDHClass, *args: object): + super().__init__(alg, conn, hash_alg, ecdh_class, *args) + + self._pq = PQDH(pq_alg_name) + + if conn.is_client(): + pq_pub, self._pq_priv = self._pq.keypair() + self._client_pub = pq_pub + self._client_pub + + def _compute_client_shared(self) -> bytes: + """Compute client shared key""" + + pq_ciphertext = self._server_pub[:self._pq.ciphertext_bytes] + ec_pub = self._server_pub[self._pq.ciphertext_bytes:] + + try: + pq_secret = self._pq.decaps(pq_ciphertext, self._pq_priv) + except ValueError: + raise ProtocolError('Invalid PQ server ciphertext') from None + + try: + ec_shared = self._priv.get_shared_bytes(ec_pub) + except ValueError: + raise ProtocolError('Invalid ECDH server public key') from None + + return String(self._hash_alg(pq_secret + ec_shared).digest()) + + def _compute_server_shared(self) -> bytes: + """Compute server shared key""" + + pq_pub = self._client_pub[:self._pq.pubkey_bytes] + ec_pub = self._client_pub[self._pq.pubkey_bytes:] + + try: + pq_secret, pq_ciphertext = self._pq.encaps(pq_pub) + except ValueError: + raise ProtocolError('Invalid PQ client public key') from None + + try: + ec_shared = self._priv.get_shared_bytes(ec_pub) + except ValueError: + raise ProtocolError('Invalid ECDH client public key') from None + + self._server_pub = pq_ciphertext + self._server_pub + + return String(self._hash_alg(pq_secret + ec_shared).digest()) + + class _KexGSSBase(_KexDHBase): """Handler for GSS key exchange""" @@ -484,6 +535,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', self._gss = conn.get_gss_context() self._token: Optional[bytes] = None + self._host_key_msg_ok = False self._host_key_data = b'' def _check_secure(self) -> None: @@ -521,11 +573,11 @@ def _send_continue(self) -> None: self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token)) - def _process_token(self, token: Optional[bytes] = None) -> None: + async def _process_token(self, token: Optional[bytes] = None) -> None: """Process a GSS token""" try: - self._token = self._gss.step(token) + self._token = await run_in_executor(self._gss.step, token) except GSSError as exc: if self._conn.is_server(): self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code), @@ -537,8 +589,8 @@ def _process_token(self, token: Optional[bytes] = None) -> None: raise KeyExchangeFailed(str(exc)) from None - def _process_init(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_gss_init(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS init message""" if self._conn.is_client(): @@ -557,7 +609,7 @@ def _process_init(self, _pkttype: int, _pktid: int, else: self._host_key_data = b'' - self._process_token(token) + await self._process_token(token) if self._gss.complete: self._check_secure() @@ -566,8 +618,8 @@ def _process_init(self, _pkttype: int, _pktid: int, else: self._send_continue() - def _process_continue(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_continue(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS continue message""" token = packet.get_string() @@ -576,7 +628,9 @@ def _process_continue(self, _pkttype: int, _pktid: int, if self._conn.is_client() and self._gss.complete: raise ProtocolError('Unexpected kexgss continue msg') - self._process_token(token) + self._host_key_msg_ok = False + + await self._process_token(token) if self._conn.is_server() and self._gss.complete: self._check_secure() @@ -584,13 +638,15 @@ def _process_continue(self, _pkttype: int, _pktid: int, else: self._send_continue() - def _process_complete(self, _pkttype: int, _pktid: int, - packet: SSHPacket) -> None: + async def _process_complete(self, _pkttype: int, _pktid: int, + packet: SSHPacket) -> None: """Process a GSS complete message""" if self._conn.is_server(): raise ProtocolError('Unexpected kexgss complete msg') + self._host_key_msg_ok = False + self._parse_server_key(packet) mic = packet.get_string() token_present = packet.get_boolean() @@ -601,7 +657,7 @@ def _process_complete(self, _pkttype: int, _pktid: int, if self._gss.complete: raise ProtocolError('Non-empty token after complete') - self._process_token(token) + await self._process_token(token) if self._token: raise ProtocolError('Non-empty token after complete') @@ -617,6 +673,10 @@ def _process_hostkey(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: """Process a GSS hostkey message""" + if not self._host_key_msg_ok: + raise ProtocolError('Unexpected kexgss hostkey msg') + + self._host_key_msg_ok = False self._host_key_data = packet.get_string() packet.check_end() @@ -636,12 +696,13 @@ def _process_error(self, _pkttype: int, _pktid: int, self._conn.logger.debug1('GSS error: %s', msg.decode('utf-8', errors='ignore')) - def start(self) -> None: + async def start(self) -> None: """Start GSS key exchange""" if self._conn.is_client(): - self._process_token() - super().start() + self._host_key_msg_ok = True + await self._process_token() + await super().start() class _KexGSS(_KexGSSBase, _KexDH): @@ -650,7 +711,7 @@ class _KexGSS(_KexGSSBase, _KexDH): _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, @@ -667,7 +728,7 @@ class _KexGSSGex(_KexGSSBase, _KexDHGex): _group_type = MSG_KEXGSS_GROUP _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, @@ -683,7 +744,7 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH): _handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_') _packet_handlers = { - MSG_KEXGSS_INIT: _KexGSSBase._process_init, + MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init, MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue, MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete, MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey, @@ -691,7 +752,23 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH): } +if mlkem_available: # pragma: no branch + if curve25519_available: # pragma: no branch + register_kex_alg(b'mlkem768x25519-sha256', _KexHybridECDH, + sha256, (b'mlkem768', Curve25519DH), True) + + register_kex_alg(b'mlkem768nistp256-sha256', _KexHybridECDH, + sha256, (b'mlkem768', ECDH, b'nistp256'), True) + register_kex_alg(b'mlkem1024nistp384-sha384', _KexHybridECDH, + sha384, (b'mlkem1024', ECDH, b'nistp384'), True) + if curve25519_available: # pragma: no branch + if sntrup_available: # pragma: no branch + register_kex_alg(b'sntrup761x25519-sha512', _KexHybridECDH, + sha512, (b'sntrup761', Curve25519DH), True) + register_kex_alg(b'sntrup761x25519-sha512@openssh.com', _KexHybridECDH, + sha512, (b'sntrup761', Curve25519DH), True) + register_kex_alg(b'curve25519-sha256', _KexECDH, sha256, (Curve25519DH,), True) register_kex_alg(b'curve25519-sha256@libssh.org', _KexECDH, sha256, diff --git a/asyncssh/kex_rsa.py b/asyncssh/kex_rsa.py index de678c5..6f5ea46 100644 --- a/asyncssh/kex_rsa.py +++ b/asyncssh/kex_rsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021 by Ron Frederick and others. +# Copyright (c) 2018-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -64,7 +64,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType, self._k = 0 self._encrypted_k = b'' - def start(self) -> None: + async def start(self) -> None: """Start RSA key exchange""" if self._conn.is_server(): @@ -142,7 +142,7 @@ def _process_secret(self, _pkttype: int, _pktid: int, self.send_packet(MSG_KEXRSA_DONE, String(sig)) - self._conn.send_newkeys(self._k, h) + self._conn.send_newkeys(MPInt(self._k), h) def _process_done(self, _pkttype: int, _pktid: int, packet: SSHPacket) -> None: @@ -161,7 +161,7 @@ def _process_done(self, _pkttype: int, _pktid: int, if not host_key.verify(h, sig): raise KeyExchangeFailed('Key exchange hash mismatch') - self._conn.send_newkeys(self._k, h) + self._conn.send_newkeys(MPInt(self._k), h) _packet_handlers = { MSG_KEXRSA_PUBKEY: _process_pubkey, diff --git a/asyncssh/known_hosts.py b/asyncssh/known_hosts.py index d80ac60..161028c 100644 --- a/asyncssh/known_hosts.py +++ b/asyncssh/known_hosts.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -94,13 +94,13 @@ def __init__(self, pattern: str): self._salt = binascii.a2b_base64(salt) self._hosthash = binascii.a2b_base64(hosthash) except (ValueError, binascii.Error): - raise ValueError('Invalid known hosts hash entry: %s' % - pattern) from None + raise ValueError( + f'Invalid known hosts hash entry: {pattern}') from None if magic != self._HMAC_SHA1_MAGIC: # Only support HMAC SHA-1 for now - raise ValueError('Invalid known hosts hash type: %s' % - magic) from None + raise ValueError( + f'Invalid known hosts hash type: {magic}') from None def _match(self, value: str) -> bool: """Return whether this host hash matches a value""" @@ -141,12 +141,12 @@ def load(self, known_hosts: str) -> None: marker = None pattern, data = line.split(None, 1) except ValueError: - raise ValueError('Invalid known hosts entry: %s' % - line) from None + raise ValueError( + f'Invalid known hosts entry: {line}') from None if marker not in (None, 'cert-authority', 'revoked'): - raise ValueError('Invalid known hosts marker: %s' % - marker) from None + raise ValueError( + f'Invalid known hosts marker: {marker}') from None key: Optional[SSHKey] = None cert: Optional[SSHCertificate] = None @@ -208,8 +208,8 @@ def _match(self, host: str, addr: str, ip = None if port: - host = '[{}]:{}'.format(host, port) if host else '' - addr = '[{}]:{}'.format(addr, port) if addr else '' + host = f'[{host}]:{port}' if host else '' + addr = f'[{addr}]:{port}' if addr else '' matches = [] matches += self._exact_entries.get(host, []) diff --git a/asyncssh/listener.py b/asyncssh/listener.py index da963bc..e9cc475 100644 --- a/asyncssh/listener.py +++ b/asyncssh/listener.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -25,7 +25,8 @@ import socket from types import TracebackType from typing import TYPE_CHECKING, AnyStr, Callable, Generic, List, Optional -from typing import Sequence, Tuple, Type, Union +from typing import Sequence, Set, Tuple, Type, Union +from typing_extensions import Self from .forward import SSHForwarderCoro from .forward import SSHLocalPortForwarder, SSHLocalPathForwarder @@ -54,7 +55,7 @@ class SSHListener: def __init__(self) -> None: self._tunnel: Optional['SSHConnection'] = None - async def __aenter__(self) -> 'SSHListener': + async def __aenter__(self) -> Self: return self async def __aexit__(self, _exc_type: Optional[Type[BaseException]], @@ -184,6 +185,11 @@ def process_connection(self, orig_host: str, orig_port: int) -> \ return chan, self._session_factory(orig_host, orig_port) + def get_addresses(self) -> List[Tuple]: + """Return the socket addresses being listened on""" + + return [(self._listen_host, self._listen_port)] + def get_port(self) -> int: """Return the port number being listened on""" @@ -279,9 +285,19 @@ async def create_tcp_local_listener( if not addrinfo: # pragma: no cover raise OSError('getaddrinfo() returned empty list') + seen_addrinfo: Set[Tuple] = set() servers: List[asyncio.AbstractServer] = [] - for family, socktype, proto, _, sa in addrinfo: + for addrinfo_entry in addrinfo: + # Work around an issue where getaddrinfo() on some systems may + # return duplicate results, causing bind to fail. + if addrinfo_entry in seen_addrinfo: # pragma: no cover + continue + + seen_addrinfo.add(addrinfo_entry) + + family, socktype, proto, _, sa = addrinfo_entry + try: sock = socket.socket(family, socktype, proto) except OSError: # pragma: no cover @@ -311,9 +327,9 @@ async def create_tcp_local_listener( exc.strerror = str(exc) # type: ignore # pylint: disable=no-member - raise OSError(exc.errno, 'error while attempting ' # type: ignore - 'to bind on address %r: %s' % - (sa, exc.strerror)) from None # type: ignore + raise OSError(exc.errno, f'error while attempting ' # type: ignore + f'to bind on address {sa!r}: ' + f'{exc.strerror}') from None # type: ignore if listen_port == 0: listen_port = sock.getsockname()[1] diff --git a/asyncssh/logging.py b/asyncssh/logging.py index 828d85c..bdf87ca 100644 --- a/asyncssh/logging.py +++ b/asyncssh/logging.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -61,29 +61,39 @@ def get_child(self, child: str = '', context: str = '') -> 'SSHLogger': def log(self, level: int, msg: object, *args, **kwargs) -> None: """Log a message to the underlying logger""" - def _text(arg: _LogArg) -> str: + def _item_text(item: _LogArg) -> str: + """Convert a list item to text""" + + if isinstance(item, bytes): + result = item.decode('utf-8', errors='backslashreplace') + + if not result.isprintable(): + result = repr(result)[1:-1] + elif not isinstance(item, str): + result = str(item) + else: + result = item + + return result + + def _text(arg: _LogArg) -> _LogArg: """Convert a log argument to text""" + result: _LogArg + if isinstance(arg, list): - if arg and isinstance(arg[0], bytes): - result = b','.join(arg).decode('utf-8', errors='replace') - else: - result = ','.join(arg) + result = ','.join(_item_text(item) for item in arg) elif isinstance(arg, tuple): host, port = arg if host: - result = '%s, port %d' % (host, port) if port else host + result = f'{host}, port {port}' if port else host else: - result = 'port %d' % port if port else 'dynamic port' + result = f'port {port}' if port else 'dynamic port' + elif isinstance(arg, bytes): + result = _item_text(arg) else: - result = cast(str, arg) - - if isinstance(result, bytes): - result = result.decode('ascii', errors='backslashreplace') - - if not result.isprintable(): - result = repr(result)[1:-1] + result = arg return result @@ -104,10 +114,10 @@ def process(self, msg: str, kwargs: _ObjDict) -> Tuple[str, _ObjDict]: offset = 0 while packet: - line = '\n %08x:' % offset + line = f'\n {offset:08x}:' for b in packet[:16]: - line += ' %02x' % b + line += f' {b:02x}' line += (62 - len(line)) * ' ' @@ -157,7 +167,7 @@ def packet(self, pktid: Optional[int], packet: bytes, msg: str, extra = cast(_ObjDict, kwargs.get('extra')) if pktid is not None: - extra.update(context='pktid=%d' % pktid) + extra.update(context=f'pktid={pktid}') extra.update(packet=packet) diff --git a/asyncssh/misc.py b/asyncssh/misc.py index 22c75d9..55f6386 100644 --- a/asyncssh/misc.py +++ b/asyncssh/misc.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,17 +20,22 @@ """Miscellaneous utility classes and functions""" +import asyncio +import fnmatch import functools import ipaddress +import os import re +import shlex import socket +import sys from pathlib import Path, PurePath from random import SystemRandom from types import TracebackType from typing import Any, AsyncContextManager, Awaitable, Callable, Dict -from typing import Generator, Generic, IO, Mapping, Optional, Sequence -from typing import Tuple, Type, TypeVar, Union, cast, overload +from typing import Generator, Generic, IO, Iterator, Mapping, Optional +from typing import Sequence, Tuple, Type, TypeVar, Union, cast, overload from typing_extensions import Literal, Protocol from .constants import DEFAULT_LANG @@ -41,6 +46,27 @@ from .constants import DISC_PROTOCOL_ERROR, DISC_PROTOCOL_VERSION_NOT_SUPPORTED from .constants import DISC_SERVICE_NOT_AVAILABLE +_pywin32_available = False + +if sys.platform == 'win32': # pragma: no cover + try: + import msvcrt + import win32file + import winioctlcon + _pywin32_available = True + except ImportError: + pass + +if sys.platform != 'win32': # pragma: no branch + import fcntl + import struct + import termios + +TermModes = Mapping[int, int] +TermModesArg = Optional[TermModes] +TermSize = Tuple[int, int, int, int] +TermSizeArg = Union[None, Tuple[int, int], TermSize] + class _Hash(Protocol): """Protocol for hashing data""" @@ -89,12 +115,17 @@ async def wait_closed(self) -> None: OptExcInfo = Union[ExcInfo, Tuple[None, None, None]] BytesOrStr = Union[bytes, str] +BytesOrStrDict = Dict[BytesOrStr, BytesOrStr] FilePath = Union[str, PurePath] HostPort = Tuple[str, int] IPAddress = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] IPNetwork = Union[ipaddress.IPv4Network, ipaddress.IPv6Network] SockAddr = Union[Tuple[str, int], Tuple[str, int, int, int]] +EnvMap = Mapping[BytesOrStr, BytesOrStr] +EnvItems = Sequence[Tuple[BytesOrStr, BytesOrStr]] +EnvSeq = Sequence[BytesOrStr] +Env = Union[EnvMap, EnvItems, EnvSeq] # Define a version of randrange which is based on SystemRandom(), so that # we get back numbers suitable for cryptographic use. @@ -107,6 +138,62 @@ async def wait_closed(self) -> None: 'd': 24*60*60, 'w': 7*24*60*60} +def encode_env(env: Env) -> Iterator[Tuple[bytes, bytes]]: + """Convert environemnt dict or list to bytes-based dictionary""" + + if hasattr(env, 'items'): + env = cast(Env, env.items()) + + try: + for item in env: + if isinstance(item, (bytes, str)): + if isinstance(item, str): + item = item.encode('utf-8') + + key_bytes, value_bytes = item.split(b'=', 1) + else: + key, value = item + + key_bytes = key.encode('utf-8') \ + if isinstance(key, str) else key + + value_bytes = value.encode('utf-8') \ + if isinstance(value, str) else value + + yield key_bytes, value_bytes + except (TypeError, ValueError) as exc: + raise ValueError(f'Invalid environment value: {exc}') from None + + +def lookup_env(patterns: EnvSeq) -> Iterator[Tuple[bytes, bytes]]: + """Look up environemnt variables with wildcard matches""" + + for pattern in patterns: + if isinstance(pattern, str): + pattern = pattern.encode('utf-8') + + if os.supports_bytes_environ: + for key_bytes, value_bytes in os.environb.items(): + if fnmatch.fnmatch(key_bytes, pattern): + yield key_bytes, value_bytes + else: # pragma: no cover + for key, value in os.environ.items(): + key_bytes = key.encode('utf-8') + value_bytes = value.encode('utf-8') + if fnmatch.fnmatch(key_bytes, pattern): + yield key_bytes, value_bytes + + +def decode_env(env: Dict[bytes, bytes]) -> Iterator[Tuple[str, str]]: + """Convert bytes-based environemnt dict to Unicode strings""" + + for key, value in env.items(): + try: + yield key.decode('utf-8'), value.decode('utf-8') + except UnicodeDecodeError: + pass + + def hide_empty(value: object, prefix: str = ', ') -> str: """Return a string with optional prefix if value is non-empty""" @@ -117,7 +204,7 @@ def hide_empty(value: object, prefix: str = ', ') -> str: def plural(length: int, label: str, suffix: str = 's') -> str: """Return a label with an optional plural suffix""" - return '%d %s%s' % (length, label, suffix if length != 1 else '') + return f'{length} {label}{suffix if length != 1 else ""}' def all_ints(seq: Sequence[object]) -> bool: @@ -229,6 +316,19 @@ def write_file(filename: FilePath, data: bytes, mode: str = 'wb') -> int: return f.write(data) +if sys.platform == 'win32' and _pywin32_available: # pragma: no cover + def make_sparse_file(file_obj: IO) -> None: + """Enable sparse file support on a file on Windows""" + + handle = msvcrt.get_osfhandle(file_obj.fileno()) + + win32file.DeviceIoControl(handle, winioctlcon.FSCTL_SET_SPARSE, + b'', 0, None) +else: + def make_sparse_file(_file_obj: IO) -> None: + """Sparse files are automatically enabled on non-Windows systems""" + + def _parse_units(value: str, suffixes: Mapping[str, int], label: str) -> float: """Parse a series of integers followed by unit suffixes""" @@ -258,6 +358,18 @@ def parse_time_interval(value: str) -> float: return _parse_units(value, _time_units, 'time interval') +def split_args(command: str) -> Sequence[str]: + """Split a command string into a list of arguments""" + + lex = shlex.shlex(command, posix=True) + lex.whitespace_split = True + + if sys.platform == 'win32': # pragma: no cover + lex.escape = [] + + return list(lex) + + _ACM = TypeVar('_ACM', bound=AsyncContextManager, covariant=True) class _ACMWrapper(Generic[_ACM]): @@ -331,6 +443,22 @@ async def maybe_wait_closed(writer: '_SupportsWaitClosed') -> None: pass +async def run_in_executor(func: Callable[..., _T], *args: object) -> _T: + """Run a function in an asyncio executor""" + + loop = asyncio.get_event_loop() + + return await loop.run_in_executor(None, func, *args) + + +def set_terminal_size(tty: IO, width: int, height: int, + pixwidth: int, pixheight: int) -> None: + """Set the terminal size of a TTY""" + + fcntl.ioctl(tty, termios.TIOCSWINSZ, + struct.pack('hhhh', height, width, pixwidth, pixheight)) + + class Options: """Container for configuration options""" @@ -339,8 +467,8 @@ class Options: def __init__(self, options: Optional['Options'] = None, **kwargs: object): if options: if not isinstance(options, type(self)): - raise TypeError('Invalid %s, got %s' % - (type(self).__name__, type(options).__name__)) + raise TypeError(f'Invalid {type(self).__name__}, ' + f'got {type(options).__name__}') self.kwargs = options.kwargs.copy() else: @@ -352,7 +480,7 @@ def __init__(self, options: Optional['Options'] = None, **kwargs: object): def prepare(self, **kwargs: object) -> None: """Pre-process configuration options""" - def update(self, kwargs: Dict[str, object]) -> None: + def update(self, **kwargs: object) -> None: """Update options based on keyword parameters passed in""" self.kwargs.update(kwargs) @@ -362,17 +490,18 @@ def update(self, kwargs: Dict[str, object]) -> None: class _RecordMeta(type): """Metaclass for general-purpose record type""" + __slots__: Dict[str, object] = {} + def __new__(mcs: Type['_RecordMeta'], name: str, bases: Tuple[type, ...], ns: Dict[str, object]) -> '_RecordMeta': + cls = cast(_RecordMeta, super().__new__(mcs, name, bases, ns)) + if name != 'Record': - fields = cast(Mapping[str, str], - ns.get('__annotations__', {})).keys() + fields = cast(Mapping[str, str], cls.__annotations__.keys()) defaults = {k: ns.get(k) for k in fields} + cls.__slots__ = defaults - ns = {k: v for k, v in ns.items() if k not in fields} - ns['__slots__'] = defaults - - return cast(_RecordMeta, super().__new__(mcs, name, bases, ns)) + return cls class Record(metaclass=_RecordMeta): @@ -391,15 +520,15 @@ def __init__(self, *args: object, **kwargs: object): setattr(self, k, v) def __repr__(self) -> str: - return '%s(%s)' % (type(self).__name__, - ', '.join('%s=%r' % (k, getattr(self, k)) - for k in self.__slots__)) + values = ', '.join(f'{k}={getattr(self, k)!r}' for k in self.__slots__) + + return f'{type(self).__name__}({values})' def __str__(self) -> str: values = ((k, self._format(k, getattr(self, k))) for k in self.__slots__) - return ', '.join('%s: %s' % (k, v) for k, v in values if v is not None) + return ', '.join(f'{k}: {v}' for k, v in values if v is not None) def _format(self, k: str, v: object) -> Optional[str]: """Format a field as a string""" @@ -684,7 +813,7 @@ class PasswordChangeRequired(Exception): """ def __init__(self, prompt: str, lang: str = DEFAULT_LANG): - super().__init__('Password change required: %s' % prompt) + super().__init__(f'Password change required: {prompt}') self.prompt = prompt self.lang = lang @@ -702,7 +831,7 @@ class BreakReceived(Exception): """ def __init__(self, msec: int): - super().__init__('Break for %s msec' % msec) + super().__init__(f'Break for {msec} msec') self.msec = msec @@ -719,7 +848,7 @@ class SignalReceived(Exception): """ def __init__(self, signal: str): - super().__init__('Signal: %s' % signal) + super().__init__(f'Signal: {signal}') self.signal = signal @@ -757,13 +886,19 @@ class TerminalSizeChanged(Exception): """ def __init__(self, width: int, height: int, pixwidth: int, pixheight: int): - super().__init__('Terminal size change: (%s, %s, %s, %s)' % - (width, height, pixwidth, pixheight)) + super().__init__(f'Terminal size change: ({width}, {height}, ' + f'{pixwidth}, {pixheight})') self.width = width self.height = height self.pixwidth = pixwidth self.pixheight = pixheight + @property + def term_size(self) -> TermSize: + """Return terminal size as a tuple of 4 integers""" + + return self.width, self.height, self.pixwidth, self.pixheight + _disc_error_map = { DISC_PROTOCOL_ERROR: ProtocolError, @@ -785,4 +920,4 @@ def construct_disc_error(code: int, reason: str, lang: str) -> DisconnectError: try: return _disc_error_map[code](reason, lang) except KeyError: - return DisconnectError(code, '%s (error %d)' % (reason, code), lang) + return DisconnectError(code, f'{reason} (error {code})', lang) diff --git a/asyncssh/packet.py b/asyncssh/packet.py index 916348d..d30d74b 100644 --- a/asyncssh/packet.py +++ b/asyncssh/packet.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,14 +20,15 @@ """SSH packet encoding and decoding functions""" -from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Union +from typing import Any, Awaitable, Callable, Iterable, Mapping, Optional +from typing import Sequence, Union from .logging import SSHLogger -from .misc import plural +from .misc import MaybeAwait, plural _LoggedPacket = Union[bytes, 'SSHPacket'] -_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], None] +_PacketHandler = Callable[[Any, int, int, 'SSHPacket'], MaybeAwait[None]] class PacketDecodeError(ValueError): @@ -192,14 +193,14 @@ def _log_packet(self, msg: str, pkttype: int, pktid: Optional[int], packet = packet.get_full_payload() try: - name = '%s (%d)' % (self._handler_names[pkttype], pkttype) + name = f'{self._handler_names[pkttype]} ({pkttype})' except KeyError: - name = 'packet type %d' % pkttype + name = f'packet type {pkttype}' count = plural(len(packet), 'byte') if note: - note = ' (%s)' % note + note = f' ({note})' self.logger.packet(pktid, packet, '%s %s, %s%s', msg, name, count, note) @@ -230,11 +231,11 @@ def logger(self) -> SSHLogger: raise NotImplementedError def process_packet(self, pkttype: int, pktid: int, - packet: SSHPacket) -> bool: + packet: SSHPacket) -> Union[bool, Awaitable[None]]: """Log and process a received packet""" if pkttype in self._packet_handlers: - self._packet_handlers[pkttype](self, pkttype, pktid, packet) - return True + return self._packet_handlers[pkttype](self, pkttype, + pktid, packet) or True else: return False diff --git a/asyncssh/pkcs11.py b/asyncssh/pkcs11.py index 0604953..b9a4740 100644 --- a/asyncssh/pkcs11.py +++ b/asyncssh/pkcs11.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021 by Ron Frederick and others. +# Copyright (c) 2020-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -68,7 +68,7 @@ class SSHPKCS11KeyPair(SSHKeyPair): _key_type = 'pkcs11' def __init__(self, session: 'SSHPKCS11Session', privkey: PrivateKey, - pubkey: SSHKey, cert: SSHCertificate = None): + pubkey: SSHKey, cert: Optional[SSHCertificate] = None): super().__init__(pubkey.algorithm, pubkey.algorithm, pubkey.sig_algorithms, pubkey.sig_algorithms, pubkey.public_data, privkey.label, cert, @@ -100,7 +100,7 @@ def sign(self, data: bytes) -> bytes: class SSHPKCS11Session: - """Work around PKCS#11 sesssions not supporting simultaneous opens""" + """Work around PKCS#11 sessions not supporting simultaneous opens""" _sessions: _SessionMap = {} @@ -197,12 +197,12 @@ def get_keys(self, load_certs: bool, key_label: Optional[str], return keys - def load_pkcs11_keys(provider: str, pin: str = None, *, + def load_pkcs11_keys(provider: str, pin: Optional[str] = None, *, load_certs: bool = True, - token_label: str = None, - token_serial: BytesOrStr = None, - key_label: str = None, - key_id: BytesOrStr = None) -> \ + token_label: Optional[str] = None, + token_serial: Optional[BytesOrStr] = None, + key_label: Optional[str] = None, + key_id: Optional[BytesOrStr] = None) -> \ Sequence[SSHPKCS11KeyPair]: """Load PIV keys and X.509 certificates from a PKCS#11 token @@ -278,12 +278,12 @@ def load_pkcs11_keys(provider: str, pin: str = None, *, return keys else: # pragma: no cover - def load_pkcs11_keys(provider: str, pin: str = None, *, + def load_pkcs11_keys(provider: str, pin: Optional[str] = None, *, load_certs: bool = True, - token_label: str = None, - token_serial: BytesOrStr = None, - key_label: str = None, - key_id: BytesOrStr = None) -> \ + token_label: Optional[str] = None, + token_serial: Optional[BytesOrStr] = None, + key_label: Optional[str] = None, + key_id: Optional[BytesOrStr] = None) -> \ Sequence['SSHPKCS11KeyPair']: """Report that PKCS#11 support is not available""" diff --git a/asyncssh/process.py b/asyncssh/process.py index d9d29e7..95282bd 100644 --- a/asyncssh/process.py +++ b/asyncssh/process.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -30,10 +30,10 @@ import socket import stat from types import TracebackType -from typing import Any, AnyStr, Callable, Dict, Generic, IO -from typing import Iterable, Mapping, Optional, Set, TextIO +from typing import Any, AnyStr, Awaitable, Callable, Dict, Generic, IO +from typing import Iterable, List, Mapping, Optional, Set, TextIO from typing import Tuple, Type, TypeVar, Union, cast -from typing_extensions import Protocol +from typing_extensions import Protocol, Self from .channel import SSHChannel, SSHClientChannel, SSHServerChannel @@ -41,16 +41,16 @@ from .logging import SSHLogger -from .misc import BytesOrStr, Error, MaybeAwait -from .misc import ProtocolError, Record, open_file +from .misc import BytesOrStr, Error, MaybeAwait, TermModes, TermSize +from .misc import ProtocolError, Record, open_file, set_terminal_size +from .misc import BreakReceived, SignalReceived, TerminalSizeChanged -from .session import DataType, TermModes, TermSize +from .session import DataType from .stream import SSHReader, SSHWriter, SSHStreamSession from .stream import SSHClientStreamSession, SSHServerStreamSession from .stream import SFTPServerFactory - _AnyStrContra = TypeVar('_AnyStrContra', bytes, str, contravariant=True) _File = Union[IO[bytes], '_AsyncFileProtocol[bytes]'] @@ -61,7 +61,12 @@ ProcessTarget = Union[int, str, socket.socket, PurePath, SSHWriter[bytes], asyncio.StreamWriter, _File] -SSHServerProcessFactory = Callable[['SSHServerProcess[AnyStr]'], None] +SSHServerProcessFactory = Callable[['SSHServerProcess[AnyStr]'], + MaybeAwait[None]] + + +_QUEUE_LOW_WATER = 8 +_QUEUE_HIGH_WATER = 16 class _AsyncFileProtocol(Protocol[AnyStr]): @@ -96,6 +101,11 @@ class _WriterProtocol(Protocol[_AnyStrContra]): def write(self, data: _AnyStrContra) -> None: """Write data""" + def write_exception(self, exc: Exception) -> None: + """Write exception (break, signal, terminal size change)""" + + return # pragma: no cover + def write_eof(self) -> None: """Close output when end of file is received""" @@ -111,11 +121,13 @@ def _is_regular_file(file: IO[bytes]) -> bool: except OSError: return True -class _UnicodeReader(Generic[AnyStr]): +class _UnicodeReader(_ReaderProtocol, Generic[AnyStr]): """Handle buffering partial Unicode data""" def __init__(self, encoding: Optional[str], errors: str, textmode: bool = False): + super().__init__() + if encoding and not textmode: self._decoder: Optional[codecs.IncrementalDecoder] = \ codecs.getincrementaldecoder(encoding)(errors) @@ -144,11 +156,13 @@ def close(self) -> None: """Perform necessary cleanup on error (provided by derived classes)""" -class _UnicodeWriter(Generic[AnyStr]): +class _UnicodeWriter(_WriterProtocol[AnyStr]): """Handle encoding Unicode data before writing it""" def __init__(self, encoding: Optional[str], errors: str, textmode: bool = False): + super().__init__() + if encoding and not textmode: self._encoder: Optional[codecs.IncrementalEncoder] = \ codecs.getincrementalencoder(encoding)(errors) @@ -265,10 +279,12 @@ def close(self) -> None: class _FileWriter(_UnicodeWriter[AnyStr]): """Forward data to a file""" - def __init__(self, file: IO[bytes], encoding: Optional[str], errors: str): + def __init__(self, file: IO[bytes], needs_close: bool, + encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._file = file + self._needs_close = needs_close def write(self, data: AnyStr) -> None: """Write data to the file""" @@ -283,24 +299,55 @@ def write_eof(self) -> None: def close(self) -> None: """Stop forwarding data to the file""" - self._file.close() + if self._needs_close: + self._file.close() class _AsyncFileWriter(_UnicodeWriter[AnyStr]): """Forward data to an aiofile""" def __init__(self, process: 'SSHProcess[AnyStr]', - file: _AsyncFileProtocol[bytes], - encoding: Optional[str], errors: str): + file: _AsyncFileProtocol[bytes], needs_close: bool, + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) - self._conn = process.channel.get_connection() + self._process: 'SSHProcess[AnyStr]' = process self._file = file + self._needs_close = needs_close + self._datatype = datatype + self._paused = False + self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() + self._write_task: Optional[asyncio.Task[None]] = \ + process.channel.get_connection().create_task(self._writer()) + + async def _writer(self) -> None: + """Process writes to the file""" + + while True: + data = await self._queue.get() + + if data is None: + self._queue.task_done() + break + + await self._file.write(self.encode(data)) + self._queue.task_done() + + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + + if self._needs_close: + await self._file.close() def write(self, data: AnyStr) -> None: """Write data to the file""" - self._conn.create_task(self._file.write(self.encode(data))) + self._queue.put_nowait(data) + + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) def write_eof(self) -> None: """Close output file when end of file is received""" @@ -310,7 +357,10 @@ def write_eof(self) -> None: def close(self) -> None: """Stop forwarding data to the file""" - self._conn.create_task(self._file.close()) + if self._write_task: + self._write_task = None + self._queue.put_nowait(None) + self._process.add_cleanup_task(self._queue.join()) class _PipeReader(_UnicodeReader[AnyStr], asyncio.BaseProtocol): @@ -329,6 +379,12 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self._transport = cast(asyncio.ReadTransport, transport) + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle closing of the pipe""" + + self._process.feed_close(self._datatype) + self.close() + def data_received(self, data: bytes) -> None: """Forward data from the pipe""" @@ -369,12 +425,25 @@ def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType, self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype self._transport: Optional[asyncio.WriteTransport] = None + self._tty: Optional[IO] = None + self._close_event = asyncio.Event() def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a newly opened pipe""" self._transport = cast(asyncio.WriteTransport, transport) + pipe = transport.get_extra_info('pipe') + + if isinstance(self._process, SSHServerProcess) and pipe.isatty(): + self._tty = pipe + set_terminal_size(pipe, *self._process.term_size) + + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle closing of the pipe""" + + self._close_event.set() + def pause_writing(self) -> None: """Pause writing to the pipe""" @@ -391,6 +460,12 @@ def write(self, data: AnyStr) -> None: assert self._transport is not None self._transport.write(self.encode(data)) + def write_exception(self, exc: Exception) -> None: + """Write terminal size changes to the pipe if it is a TTY""" + + if isinstance(exc, TerminalSizeChanged) and self._tty: + set_terminal_size(self._tty, *exc.term_size) + def write_eof(self) -> None: """Write EOF to the pipe""" @@ -402,12 +477,14 @@ def close(self) -> None: assert self._transport is not None self._transport.close() + self._process.add_cleanup_task(self._close_event.wait()) -class _ProcessReader(Generic[AnyStr]): +class _ProcessReader(_ReaderProtocol, Generic[AnyStr]): """Forward data from another SSH process""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType): + super().__init__() self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype @@ -427,10 +504,11 @@ def close(self) -> None: self._process.clear_writer(self._datatype) -class _ProcessWriter(Generic[AnyStr]): +class _ProcessWriter(_WriterProtocol[AnyStr]): """Forward data to another SSH process""" def __init__(self, process: 'SSHProcess[AnyStr]', datatype: DataType): + super().__init__() self._process: 'SSHProcess[AnyStr]' = process self._datatype = datatype @@ -439,6 +517,11 @@ def write(self, data: AnyStr) -> None: self._process.feed_data(data, self._datatype) + def write_exception(self, exc: Exception) -> None: + """Write an exception to the other channel""" + + cast(SSHClientProcess, self._process).feed_exception(exc) + def write_eof(self) -> None: """Write EOF to the other channel""" @@ -502,27 +585,65 @@ def close(self) -> None: class _StreamWriter(_UnicodeWriter[AnyStr]): """Forward data to an asyncio stream""" - def __init__(self, writer: asyncio.StreamWriter, - encoding: Optional[str], errors: str): + def __init__(self, process: 'SSHProcess[AnyStr]', + writer: asyncio.StreamWriter, recv_eof: bool, + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors) + self._process: 'SSHProcess[AnyStr]' = process self._writer = writer + self._recv_eof = recv_eof + self._datatype = datatype + self._paused = False + self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() + self._write_task: Optional[asyncio.Task[None]] = \ + process.channel.get_connection().create_task(self._feed()) + + async def _feed(self) -> None: + """Feed data to the stream""" + + while True: + data = await self._queue.get() + + if data is None: + self._queue.task_done() + break + + self._writer.write(self.encode(data)) + await self._writer.drain() + self._queue.task_done() + + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + + if self._recv_eof: + self._writer.write_eof() def write(self, data: AnyStr) -> None: """Write data to the stream""" - self._writer.write(self.encode(data)) + self._queue.put_nowait(data) + + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) def write_eof(self) -> None: """Write EOF to the stream""" - self._writer.write_eof() + self.close() def close(self) -> None: - """Ignore close -- the caller must clean up the associated transport""" + """Stop forwarding data to the stream""" + if self._write_task: + self._write_task = None + self._queue.put_nowait(None) + self._process.add_cleanup_task(self._queue.join()) -class _DevNullWriter(Generic[AnyStr]): + +class _DevNullWriter(_WriterProtocol[AnyStr]): """Discard data""" def write(self, data: AnyStr) -> None: @@ -535,10 +656,11 @@ def close(self) -> None: """Ignore close""" -class _StdoutWriter(Generic[AnyStr]): +class _StdoutWriter(_WriterProtocol[AnyStr]): """Forward data to an SSH process' stdout instead of stderr""" def __init__(self, process: 'SSHProcess[AnyStr]'): + super().__init__() self._process: 'SSHProcess[AnyStr]' = process def write(self, data: AnyStr) -> None: @@ -607,12 +729,11 @@ def __init__(self, env: Optional[Mapping[str, str]], if exit_signal: signal, core_dumped, msg, lang = exit_signal - reason = 'Process exited with signal %s%s%s' % \ - (signal, ': ' + msg if msg else '', - ' (core dumped)' if core_dumped else '') + reason = 'Process exited with signal ' + signal + \ + (': ' + msg if msg else '') + \ + (' (core dumped)' if core_dumped else '') elif exit_status: - reason = 'Process exited with non-zero exit status %s' % \ - exit_status + reason = f'Process exited with non-zero exit status {exit_status}' super().__init__(exit_status or 0, reason, lang) @@ -685,12 +806,30 @@ class SSHProcess(SSHStreamSession, Generic[AnyStr]): def __init__(self, *args) -> None: super().__init__(*args) + self._cleanup_tasks: List[Awaitable[None]] = [] + self._readers: Dict[Optional[int], _ReaderProtocol] = {} self._send_eof: Dict[Optional[int], bool] = {} self._writers: Dict[Optional[int], _WriterProtocol[AnyStr]] = {} + self._recv_eof: Dict[Optional[int], bool] = {} + self._paused_write_streams: Set[Optional[int]] = set() + async def __aenter__(self) -> Self: + """Allow SSHProcess to be used as an async context manager""" + + return self + + async def __aexit__(self, _exc_type: Optional[Type[BaseException]], + _exc_value: Optional[BaseException], + _traceback: Optional[TracebackType]) -> bool: + """Wait for a full channel close when exiting the async context""" + + self.close() + await self.wait_closed() + return False + @property def channel(self) -> SSHChannel[AnyStr]: """The channel associated with the process""" @@ -749,8 +888,8 @@ def get_extra_info(self, name: str, default: Any = None) -> Any: assert self._chan is not None return self._chan.get_extra_info(name, default) - async def _create_reader(self, source: ProcessSource, - bufsize: int, send_eof: bool, + async def _create_reader(self, source: ProcessSource, bufsize: int, + send_eof: bool, recv_eof: bool, datatype: DataType = None) -> None: """Create a reader to forward data to the SSH channel""" @@ -769,7 +908,7 @@ def pipe_factory() -> _PipeReader: reader_stream, reader_datatype = source.get_redirect_info() reader_process = cast('SSHProcess[AnyStr]', reader_stream) writer = _ProcessWriter[AnyStr](self, datatype) - reader_process.set_writer(writer, reader_datatype) + reader_process.set_writer(writer, recv_eof, reader_datatype) reader = _ProcessReader(reader_process, reader_datatype) elif isinstance(source, asyncio.StreamReader): reader = _StreamReader(self, source, bufsize, datatype, @@ -789,7 +928,7 @@ def pipe_factory() -> _PipeReader: file = source if hasattr(file, 'read') and \ - (asyncio.iscoroutinefunction(file.read) or + (inspect.iscoroutinefunction(file.read) or inspect.isgeneratorfunction(file.read)): reader = _AsyncFileReader(self, cast(_AsyncFileProtocol, file), bufsize, datatype, self._encoding, @@ -814,8 +953,8 @@ def pipe_factory() -> _PipeReader: elif isinstance(reader, _ProcessReader): reader_process.feed_recv_buf(reader_datatype, writer) - async def _create_writer(self, target: ProcessTarget, - bufsize: int, send_eof: bool, + async def _create_writer(self, target: ProcessTarget, bufsize: int, + send_eof: bool, recv_eof: bool, datatype: DataType = None) -> None: """Create a writer to forward data from the SSH channel""" @@ -837,40 +976,50 @@ def pipe_factory() -> _PipeWriter: writer_process.set_reader(reader, send_eof, writer_datatype) writer = _ProcessWriter[AnyStr](writer_process, writer_datatype) elif isinstance(target, asyncio.StreamWriter): - writer = _StreamWriter(target, self._encoding, self._errors) + writer = _StreamWriter(self, target, recv_eof, datatype, + self._encoding, self._errors) else: file: _File + needs_close = True if isinstance(target, str): file = open_file(target, 'wb', buffering=bufsize) elif isinstance(target, PurePath): file = open_file(str(target), 'wb', buffering=bufsize) elif isinstance(target, int): - file = os.fdopen(target, 'wb', buffering=bufsize) + file = os.fdopen(target, 'wb', + buffering=bufsize, closefd=recv_eof) elif isinstance(target, socket.socket): - file = os.fdopen(target.detach(), 'wb', buffering=bufsize) + fd = target.detach() if recv_eof else target.fileno() + file = os.fdopen(fd, 'wb', buffering=bufsize, closefd=recv_eof) else: file = target + needs_close = recv_eof if hasattr(file, 'write') and \ - (asyncio.iscoroutinefunction(file.write) or + (inspect.iscoroutinefunction(file.write) or inspect.isgeneratorfunction(file.write)): - writer = _AsyncFileWriter(self, cast(_AsyncFileProtocol, file), - self._encoding, self._errors) + writer = _AsyncFileWriter( + self, cast(_AsyncFileProtocol, file), needs_close, + datatype, self._encoding, self._errors) elif _is_regular_file(cast(IO[bytes], file)): - writer = _FileWriter(cast(IO[bytes], file), self._encoding, - self._errors) + writer = _FileWriter(cast(IO[bytes], file), needs_close, + self._encoding, self._errors) else: if hasattr(target, 'buffer'): # If file was opened in text mode, remove that wrapper file = cast(TextIO, target).buffer + if not recv_eof: + fd = os.dup(cast(IO[bytes], file).fileno()) + file = os.fdopen(fd, 'wb', buffering=0) + assert self._loop is not None _, protocol = \ await self._loop.connect_write_pipe(pipe_factory, file) writer = cast(_PipeWriter, protocol) - self.set_writer(writer, datatype) + self.set_writer(writer, recv_eof, datatype) if writer: self.feed_recv_buf(datatype, writer) @@ -887,6 +1036,11 @@ def _should_pause_reading(self) -> bool: return bool(self._paused_write_streams) or \ super()._should_pause_reading() + def add_cleanup_task(self, task: Awaitable) -> None: + """Add a task to run when the process exits""" + + self._cleanup_tasks.append(task) + def connection_lost(self, exc: Optional[Exception]) -> None: """Handle a close of the SSH channel""" @@ -914,8 +1068,9 @@ def data_received(self, data: AnyStr, datatype: DataType) -> None: def eof_received(self) -> bool: """Handle an incoming end of file from the SSH channel""" - for writer in list(self._writers.values()): - writer.write_eof() + for datatype, writer in list(self._writers.items()): + if self._recv_eof[datatype]: + writer.write_eof() return super().eof_received() @@ -951,14 +1106,22 @@ def feed_eof(self, datatype: DataType) -> None: self._readers[datatype].close() self.clear_reader(datatype) + def feed_close(self, datatype: DataType) -> None: + """Feed pipe close to the channel""" + + if datatype in self._readers: + self.feed_eof(datatype) + def feed_recv_buf(self, datatype: DataType, writer: _WriterProtocol[AnyStr]) -> None: """Feed current receive buffer to a newly set writer""" for buf in self._recv_buf[datatype]: - data = cast(AnyStr, buf) - writer.write(data) - self._recv_buf_len -= len(data) + if isinstance(buf, Exception): + writer.write_exception(buf) + else: + writer.write(buf) + self._recv_buf_len -= len(buf) self._recv_buf[datatype].clear() @@ -1005,7 +1168,7 @@ def clear_reader(self, datatype: DataType) -> None: self._unblock_drain(datatype) def set_writer(self, writer: Optional[_WriterProtocol[AnyStr]], - datatype: DataType) -> None: + recv_eof: bool, datatype: DataType) -> None: """Set a writer used to forward data from the channel""" old_writer = self._writers.get(datatype) @@ -1016,6 +1179,7 @@ def set_writer(self, writer: Optional[_WriterProtocol[AnyStr]], if writer: self._writers[datatype] = writer + self._recv_eof[datatype] = recv_eof def clear_writer(self, datatype: DataType) -> None: """Clear a writer forwarding data from the channel""" @@ -1043,6 +1207,11 @@ async def wait_closed(self) -> None: assert self._chan is not None await self._chan.wait_closed() + for task in self._cleanup_tasks: + await task + + self._cleanup_tasks = [] + class SSHClientProcess(SSHProcess[AnyStr], SSHClientStreamSession[AnyStr]): """SSH client process handler""" @@ -1057,20 +1226,6 @@ def __init__(self) -> None: self._stdout: Optional[SSHReader[AnyStr]] = None self._stderr: Optional[SSHReader[AnyStr]] = None - async def __aenter__(self) -> 'SSHClientProcess[AnyStr]': - """Allow SSHProcess to be used as an async context manager""" - - return self - - async def __aexit__(self, _exc_type: Optional[Type[BaseException]], - _exc_value: Optional[BaseException], - _traceback: Optional[TracebackType]) -> bool: - """Wait for a full channel close when exiting the async context""" - - self.close() - await self._chan.wait_closed() - return False - def _collect_output(self, datatype: DataType = None) -> AnyStr: """Return output from the process""" @@ -1130,11 +1285,22 @@ def stderr(self) -> SSHReader[AnyStr]: assert self._stderr is not None return self._stderr + def feed_exception(self, exc: Exception) -> None: + """Feed exception to the channel""" + + if isinstance(exc, TerminalSizeChanged): + self._chan.change_terminal_size(exc.width, exc.height, + exc.pixwidth, exc.pixheight) + elif isinstance(exc, BreakReceived): + self._chan.send_break(exc.msec) + elif isinstance(exc, SignalReceived): # pragma: no branch + self._chan.send_signal(exc.signal) + async def redirect(self, stdin: Optional[ProcessSource] = None, stdout: Optional[ProcessTarget] = None, stderr: Optional[ProcessTarget] = None, bufsize: int =io.DEFAULT_BUFFER_SIZE, - send_eof: bool = True) -> None: + send_eof: bool = True, recv_eof: bool = True) -> None: """Perform I/O redirection for the process This method redirects data going to or from any or all of @@ -1174,6 +1340,14 @@ async def redirect(self, stdin: Optional[ProcessSource] = None, The default value of `None` means to not change redirection for that stream. + .. note:: While it is legal to use buffered I/O streams such + as sys.stdin, sys.stdout, and sys.stderr as redirect + targets, you must make sure buffers are flushed + before redirection begins and that these streams + are put back into blocking mode before attempting + to go back using buffered I/O again. Also, no buffered + I/O should be performed while redirection is active. + .. note:: When passing in asyncio streams, it is the responsibility of the caller to close the associated transport when it is no longer needed. @@ -1187,44 +1361,52 @@ async def redirect(self, stdin: Optional[ProcessSource] = None, :param bufsize: Buffer size to use when forwarding data from a file :param send_eof: - Whether or not to send EOF to the channel when redirection - is complete, defaulting to `True`. If set to `False`, - multiple sources can be sequentially fed to the channel. + Whether or not to send EOF to the channel when EOF is + received from stdin, defaulting to `True`. If set to `False`, + the channel will remain open after EOF is received on stdin, + and multiple sources can be redirected to the channel. + :param recv_eof: + Whether or not to send EOF to stdout and stderr when EOF is + received from the channel, defaulting to `True`. If set to + `False`, the redirect targets of stdout and stderr will remain + open after EOF is received on the channel and can be used for + multiple redirects. :type bufsize: `int` :type send_eof: `bool` + :type recv_eof: `bool` """ if stdin: - await self._create_reader(stdin, bufsize, send_eof) + await self._create_reader(stdin, bufsize, send_eof, recv_eof) if stdout: - await self._create_writer(stdout, bufsize, send_eof) + await self._create_writer(stdout, bufsize, send_eof, recv_eof) if stderr: - await self._create_writer(stderr, bufsize, send_eof, + await self._create_writer(stderr, bufsize, send_eof, recv_eof, EXTENDED_DATA_STDERR) async def redirect_stdin(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, - send_eof : bool = True) -> None: + send_eof: bool = True) -> None: """Redirect standard input of the process""" - await self.redirect(source, None, None, bufsize, send_eof) + await self.redirect(source, None, None, bufsize, send_eof, True) async def redirect_stdout(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, - send_eof: bool = True) -> None: + recv_eof: bool = True) -> None: """Redirect standard output of the process""" - await self.redirect(None, target, None, bufsize, send_eof) + await self.redirect(None, target, None, bufsize, True, recv_eof) async def redirect_stderr(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, - send_eof: bool = True) -> None: + recv_eof: bool = True) -> None: """Redirect standard error of the process""" - await self.redirect(None, None, target, bufsize, send_eof) + await self.redirect(None, None, target, bufsize, True, recv_eof) def collect_output(self) -> Tuple[AnyStr, AnyStr]: """Collect output from the process without blocking @@ -1266,7 +1448,7 @@ async def communicate(self, input: Optional[AnyStr] = None) -> \ self._chan.write(input) self._chan.write_eof() - await self._chan.wait_closed() + await self.wait_closed() return self.collect_output() # pylint: enable=redefined-builtin @@ -1415,20 +1597,6 @@ def __init__(self, process_factory: SSHServerProcessFactory, self._stdout: Optional[SSHWriter[AnyStr]] = None self._stderr: Optional[SSHWriter[AnyStr]] = None - async def __aenter__(self) -> 'SSHServerProcess[AnyStr]': - """Allow SSHProcess to be used as an async context manager""" - - return self - - async def __aexit__(self, _exc_type: Optional[Type[BaseException]], - _exc_value: Optional[BaseException], - _traceback: Optional[TracebackType]) -> bool: - """Wait for a full channel close when exiting the async context""" - - self.close() - await self._chan.wait_closed() - return False - def _start_process(self, stdin: SSHReader[AnyStr], stdout: SSHWriter[AnyStr], stderr: SSHWriter[AnyStr]) -> MaybeAwait[None]: @@ -1497,11 +1665,21 @@ def stderr(self) -> SSHWriter[AnyStr]: assert self._stderr is not None return self._stderr + def exception_received(self, exc: Exception) -> None: + """Handle an incoming exception on the channel""" + + writer = self._writers.get(None) + + if writer: + writer.write_exception(exc) + else: + super().exception_received(exc) + async def redirect(self, stdin: Optional[ProcessTarget] = None, stdout: Optional[ProcessSource] = None, stderr: Optional[ProcessSource] = None, bufsize: int = io.DEFAULT_BUFFER_SIZE, - send_eof: bool = True) -> None: + send_eof: bool = True, recv_eof: bool = True) -> None: """Perform I/O redirection for the process This method redirects data going to or from any or all of @@ -1551,44 +1729,53 @@ async def redirect(self, stdin: Optional[ProcessTarget] = None, :param bufsize: Buffer size to use when forwarding data from a file :param send_eof: - Whether or not to send EOF to the channel when redirection - is complete, defaulting to `True`. If set to `False`, - multiple sources can be sequentially fed to the channel. + Whether or not to send EOF to the channel when EOF is + received from stdout or stderr, defaulting to `True`. If + set to `False`, the channel will remain open after EOF is + received on stdout or stderr, and multiple sources can be + redirected to the channel. + :param recv_eof: + Whether or not to send EOF to stdin when EOF is received + on the channel, defaulting to `True`. If set to `False`, + the redirect target of stdin will remain open after EOF + is received on the channel and can be used for multiple + redirects. :type bufsize: `int` :type send_eof: `bool` + :type recv_eof: `bool` """ if stdin: - await self._create_writer(stdin, bufsize, send_eof) + await self._create_writer(stdin, bufsize, send_eof, recv_eof) if stdout: - await self._create_reader(stdout, bufsize, send_eof) + await self._create_reader(stdout, bufsize, send_eof, recv_eof) if stderr: - await self._create_reader(stderr, bufsize, send_eof, + await self._create_reader(stderr, bufsize, send_eof, recv_eof, EXTENDED_DATA_STDERR) async def redirect_stdin(self, target: ProcessTarget, bufsize: int = io.DEFAULT_BUFFER_SIZE, - send_eof: bool = True) -> None: + recv_eof: bool = True) -> None: """Redirect standard input of the process""" - await self.redirect(target, None, None, bufsize, send_eof) + await self.redirect(target, None, None, bufsize, True, recv_eof) async def redirect_stdout(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True) -> None: """Redirect standard output of the process""" - await self.redirect(None, source, None, bufsize, send_eof) + await self.redirect(None, source, None, bufsize, send_eof, True) async def redirect_stderr(self, source: ProcessSource, bufsize: int = io.DEFAULT_BUFFER_SIZE, send_eof: bool = True) -> None: """Redirect standard error of the process""" - await self.redirect(None, None, source, bufsize, send_eof) + await self.redirect(None, None, source, bufsize, send_eof, True) def get_terminal_type(self) -> Optional[str]: """Return the terminal type set by the client for the process diff --git a/asyncssh/public_key.py b/asyncssh/public_key.py index 4261be6..caa5561 100644 --- a/asyncssh/public_key.py +++ b/asyncssh/public_key.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,7 +20,9 @@ """SSH asymmetric encryption handlers""" +import asyncio import binascii +import inspect import os import re import time @@ -70,6 +72,7 @@ _PubKeyAlgMap = Dict[bytes, Type['SSHKey']] _CertAlgMap = Dict[bytes, Tuple[Optional[Type['SSHKey']], Type['SSHCertificate']]] +_CertSigAlgMap = Dict[bytes, bytes] _CertVersionMap = Dict[Tuple[bytes, int], Tuple[bytes, Type['SSHOpenSSHCertificate']]] @@ -94,6 +97,8 @@ _KeyPairArg = Union['SSHKeyPair', _KeyArg, Tuple[_KeyArg, _CertArg]] KeyPairListArg = Union[_KeyPairArg, Sequence[_KeyPairArg]] +_PassphraseCallable = Callable[[str], BytesOrStr] +_PassphraseArg = Optional[Union[_PassphraseCallable, BytesOrStr]] # Default file names in .ssh directory to read private keys from _DEFAULT_KEY_FILES = ( @@ -128,6 +133,7 @@ _public_key_alg_map: _PubKeyAlgMap = {} _certificate_alg_map: _CertAlgMap = {} +_certificate_sig_alg_map: _CertSigAlgMap = {} _certificate_version_map: _CertVersionMap = {} _pem_map: _PEMMap = {} _pkcs8_oid_map: _PKCS8OIDMap = {} @@ -188,6 +194,51 @@ def _wrap_base64(data: bytes, wrap: int = 64) -> bytes: for i in range(0, len(data), wrap)) + b'\n' +def _resolve_passphrase( + passphrase: _PassphraseArg, filename: str, + loop: Optional[asyncio.AbstractEventLoop]) -> Optional[BytesOrStr]: + """Resolve a passphrase used to encrypt/decrypt SSH private keys""" + + resolved_passphrase: Optional[BytesOrStr] + + if callable(passphrase): + resolved_passphrase = passphrase(filename) + else: + resolved_passphrase = passphrase + + if loop and inspect.isawaitable(resolved_passphrase): + resolved_passphrase = asyncio.run_coroutine_threadsafe( + resolved_passphrase, loop).result() + + return resolved_passphrase + + +class _EncryptedKey: + """Encrypted SSH private key, decrypted just prior to use""" + + def __init__(self, key_data: bytes, filename: str, + passphrase: _PassphraseArg, + loop: Optional[asyncio.AbstractEventLoop], + unsafe_skip_rsa_key_validation: bool): + self._key_data = key_data + self._filename = filename + self._passphrase = passphrase + self._loop = loop + self._unsafe_skip_rsa_key_validation = unsafe_skip_rsa_key_validation + + def decrypt(self) -> 'SSHKey': + """Decrypt this encrypted key data and return an SSH private key""" + + resolved_passphrase = _resolve_passphrase(self._passphrase, + self._filename, self._loop) + + key = import_private_key(self._key_data, resolved_passphrase, + self._unsafe_skip_rsa_key_validation) + key.set_filename(self._filename) + + return key + + class KeyGenerationError(ValueError): """Key generation error @@ -238,12 +289,14 @@ class SSHKey: algorithm: bytes = b'' sig_algorithms: Sequence[bytes] = () + cert_algorithms: Sequence[bytes] = () x509_algorithms: Sequence[bytes] = () all_sig_algorithms: Set[bytes] = set() - default_hash_name: str = '' + default_x509_hash: str = '' pem_name: bytes = b'' pkcs8_oid: Optional[ObjectIdentifier] = None use_executor: bool = False + use_webauthn: bool = False def __init__(self, key: Optional[CryptoKey] = None): self._key = key @@ -259,13 +312,13 @@ def generate(cls, algorithm: bytes, **kwargs) -> 'SSHKey': @classmethod def make_private(cls, key_params: object) -> 'SSHKey': - """Construct an RSA private key""" + """Construct a private key""" raise NotImplementedError @classmethod def make_public(cls, key_params: object) -> 'SSHKey': - """Construct an RSA public key""" + """Construct a public key""" raise NotImplementedError @@ -317,6 +370,7 @@ def _generate_certificate(self, key: 'SSHKey', version: int, serial: int, principals: _CertPrincipals, valid_after: _Time, valid_before: _Time, cert_options: _OpenSSHCertOptions, + sig_alg_name: DefTuple[str], comment: DefTuple[_Comment]) -> \ 'SSHOpenSSHCertificate': """Generate a new SSH certificate""" @@ -333,6 +387,11 @@ def _generate_certificate(self, key: 'SSHKey', version: int, serial: int, raise ValueError('Valid before time must be later than ' 'valid after time') + if sig_alg_name == (): + sig_alg = self.sig_algorithms[0] + else: + sig_alg = cast(str, sig_alg_name).encode() + if comment == (): comment = key.get_comment_bytes() @@ -346,7 +405,8 @@ def _generate_certificate(self, key: 'SSHKey', version: int, serial: int, return cert_handler.generate(self, algorithm, key, serial, cert_type, key_id, principals, valid_after, - valid_before, cert_options, comment) + valid_before, cert_options, + sig_alg, comment) def _generate_x509_certificate(self, key: 'SSHKey', subject: str, issuer: Optional[str], @@ -378,7 +438,7 @@ def _generate_x509_certificate(self, key: 'SSHKey', subject: str, 'valid after time') if hash_name == (): - hash_name = key.default_hash_name + hash_name = key.default_x509_hash if comment == (): comment = key.get_comment_bytes() @@ -626,24 +686,21 @@ def convert_to_public(self) -> 'SSHKey': result.set_filename(self._filename) return result - def generate_user_certificate(self, user_key: 'SSHKey', key_id: str, - version: int = 1, serial: int = 0, - principals: _CertPrincipals = (), - valid_after: _Time = 0, - valid_before: _Time = 0xffffffffffffffff, - force_command: str = None, - source_address: Sequence[str] = None, - permit_x11_forwarding: bool = True, - permit_agent_forwarding: bool = True, - permit_port_forwarding: bool = True, - permit_pty: bool = True, - permit_user_rc: bool = True, - touch_required: bool = True, - comment: DefTuple[_Comment] = ()) -> \ - 'SSHOpenSSHCertificate': + def generate_user_certificate( + self, user_key: 'SSHKey', key_id: str, version: int = 1, + serial: int = 0, principals: _CertPrincipals = (), + valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, + force_command: Optional[str] = None, + source_address: Optional[Sequence[str]] = None, + permit_x11_forwarding: bool = True, + permit_agent_forwarding: bool = True, + permit_port_forwarding: bool = True, permit_pty: bool = True, + permit_user_rc: bool = True, touch_required: bool = True, + sig_alg: DefTuple[str] = (), + comment: DefTuple[_Comment] = ()) -> 'SSHOpenSSHCertificate': """Generate a new SSH user certificate - This method returns an SSH user certifcate with the requested + This method returns an SSH user certificate with the requested attributes signed by this private key. :param user_key: @@ -689,6 +746,8 @@ def generate_user_certificate(self, user_key: 'SSHKey', key_id: str, :param touch_required: (optional) Whether or not to require the user to touch the security key when authenticating with it, defaulting to `True`. + :param sig_alg: (optional) + The algorithm to use when signing the new certificate. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on @@ -706,6 +765,7 @@ def generate_user_certificate(self, user_key: 'SSHKey', key_id: str, :type permit_pty: `bool` :type permit_user_rc: `bool` :type touch_required: `bool` + :type sig_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` @@ -746,18 +806,20 @@ def generate_user_certificate(self, user_key: 'SSHKey', key_id: str, return self._generate_certificate(user_key, version, serial, CERT_TYPE_USER, key_id, principals, valid_after, - valid_before, cert_options, comment) + valid_before, cert_options, + sig_alg, comment) def generate_host_certificate(self, host_key: 'SSHKey', key_id: str, version: int = 1, serial: int = 0, principals: _CertPrincipals = (), valid_after: _Time = 0, valid_before: _Time = 0xffffffffffffffff, + sig_alg: DefTuple[str] = (), comment: DefTuple[_Comment] = ()) -> \ 'SSHOpenSSHCertificate': """Generate a new SSH host certificate - This method returns an SSH host certifcate with the requested + This method returns an SSH host certificate with the requested attributes signed by this private key. :param host_key: @@ -779,6 +841,8 @@ def generate_host_certificate(self, host_key: 'SSHKey', key_id: str, The latest time the certificate is valid for, defaulting to no restriction on when the certificate stops being valid. See :ref:`SpecifyingTimeValues` for allowed time specifications. + :param sig_alg: (optional) + The algorithm to use when signing the new certificate. :param comment: The comment to associate with this certificate. By default, the comment will be set to the comment currently set on @@ -788,6 +852,7 @@ def generate_host_certificate(self, host_key: 'SSHKey', key_id: str, :type version: `int` :type serial: `int` :type principals: `str` or `list` of `str` + :type sig_alg: `str` :type comment: `str`, `bytes`, or `None` :returns: :class:`SSHCertificate` @@ -803,21 +868,19 @@ def generate_host_certificate(self, host_key: 'SSHKey', key_id: str, return self._generate_certificate(host_key, version, serial, CERT_TYPE_HOST, key_id, principals, valid_after, - valid_before, {}, comment) - - def generate_x509_user_certificate(self, user_key: 'SSHKey', subject: str, - issuer: str = None, serial: int = None, - principals: _CertPrincipals = (), - valid_after: _Time = 0, - valid_before: _Time = 0xffffffffffffffff, - purposes: X509CertPurposes = \ - 'secureShellClient', - hash_alg: DefTuple[str] = (), - comment: DefTuple[_Comment] = ()) -> \ - 'SSHX509Certificate': + valid_before, {}, sig_alg, comment) + + def generate_x509_user_certificate( + self, user_key: 'SSHKey', subject: str, + issuer: Optional[str] = None, serial: Optional[int] = None, + principals: _CertPrincipals = (), valid_after: _Time = 0, + valid_before: _Time = 0xffffffffffffffff, + purposes: X509CertPurposes = 'secureShellClient', + hash_alg: DefTuple[str] = (), + comment: DefTuple[_Comment] = ()) -> 'SSHX509Certificate': """Generate a new X.509 user certificate - This method returns an X.509 user certifcate with the requested + This method returns an X.509 user certificate with the requested attributes signed by this private key. :param user_key: @@ -878,19 +941,17 @@ def generate_x509_user_certificate(self, user_key: 'SSHKey', subject: str, purposes, principals, (), hash_alg, comment) - def generate_x509_host_certificate(self, host_key: 'SSHKey', subject: str, - issuer: str = None, serial: int = None, - principals: _CertPrincipals = (), - valid_after: _Time = 0, - valid_before: _Time = 0xffffffffffffffff, - purposes: X509CertPurposes = \ - 'secureShellServer', - hash_alg: DefTuple[str] = (), - comment: DefTuple[_Comment] = ()) -> \ - 'SSHX509Certificate': + def generate_x509_host_certificate( + self, host_key: 'SSHKey', subject: str, + issuer: Optional[str] = None, serial: Optional[int] = None, + principals: _CertPrincipals = (), valid_after: _Time = 0, + valid_before: _Time = 0xffffffffffffffff, + purposes: X509CertPurposes = 'secureShellServer', + hash_alg: DefTuple[str] = (), + comment: DefTuple[_Comment] = ()) -> 'SSHX509Certificate': """Generate a new X.509 host certificate - This method returns a X.509 host certifcate with the requested + This method returns an X.509 host certificate with the requested attributes signed by this private key. :param host_key: @@ -961,7 +1022,7 @@ def generate_x509_ca_certificate(self, ca_key: 'SSHKey', subject: str, 'SSHX509Certificate': """Generate a new X.509 CA certificate - This method returns a X.509 CA certifcate with the requested + This method returns an X.509 CA certificate with the requested attributes signed by this private key. :param ca_key: @@ -1158,7 +1219,7 @@ def export_private_key(self, format_name: str = 'openssh', key_size, iv_size, block_size, _, _, _ = \ get_encryption_params(alg) except (KeyError, UnicodeEncodeError): - raise KeyEncryptionError('Unknown cipher: %s' % + raise KeyEncryptionError('Unknown cipher: ' + cipher_name) from None if not _bcrypt_available: # pragma: no cover @@ -1360,6 +1421,8 @@ def construct(cls, packet: SSHPacket, algorithm: bytes, comment: _Comment) -> 'SSHCertificate': """Construct an SSH certificate from packetized data""" + raise NotImplementedError + def __eq__(self, other: object) -> bool: return (isinstance(other, type(self)) and self.public_data == other.public_data) @@ -1547,7 +1610,8 @@ def __init__(self, algorithm: bytes, key: SSHKey, data: bytes, signing_key: SSHKey, serial: int, cert_type: int, key_id: str, valid_after: int, valid_before: int, comment: _Comment): - super().__init__(algorithm, key.sig_algorithms, (algorithm,), + super().__init__(algorithm, key.sig_algorithms, + key.cert_algorithms or (algorithm,), key, data, comment) self.principals = principals @@ -1565,7 +1629,7 @@ def generate(cls, signing_key: 'SSHKey', algorithm: bytes, key: 'SSHKey', serial: int, cert_type: int, key_id: str, principals: Sequence[str], valid_after: int, valid_before: int, options: _OpenSSHCertOptions, - comment: _Comment) -> 'SSHOpenSSHCertificate': + sig_alg: bytes, comment: _Comment) -> 'SSHOpenSSHCertificate': """Generate a new SSH certificate""" principal_bytes = b''.join(String(p) for p in principals) @@ -1590,7 +1654,7 @@ def generate(cls, signing_key: 'SSHKey', algorithm: bytes, key: 'SSHKey', cert_extensions), String(signing_key.public_data))) - data += String(signing_key.sign(data, signing_key.algorithm)) + data += String(signing_key.sign(data, sig_alg)) signing_key = signing_key.convert_to_public() @@ -1601,7 +1665,7 @@ def generate(cls, signing_key: 'SSHKey', algorithm: bytes, key: 'SSHKey', @classmethod def construct(cls, packet: SSHPacket, algorithm: bytes, key_handler: Optional[Type[SSHKey]], - comment: _Comment) -> 'SSHCertificate': + comment: _Comment) -> 'SSHOpenSSHCertificate': """Construct an SSH certificate from packetized data""" assert key_handler is not None @@ -1746,12 +1810,12 @@ def _decode_options(options: bytes, decoders: _OpenSSHCertDecoders, result[name.decode('ascii')] = decoder(data_packet) data_packet.check_end() elif critical: - raise KeyImportError('Unrecognized critical option: %s' % + raise KeyImportError('Unrecognized critical option: ' + name.decode('ascii', errors='replace')) return result - def validate(self, cert_type: int, principal: str) -> None: + def validate(self, cert_type: int, principal: Optional[str]) -> None: """Validate an OpenSSH certificate""" if self._cert_type != cert_type: @@ -1765,7 +1829,8 @@ def validate(self, cert_type: int, principal: str) -> None: if now >= self._valid_before: raise ValueError('Certificate expired') - if principal and self.principals and principal not in self.principals: + if principal is not None and self.principals and \ + principal not in self.principals: raise ValueError('Certificate principal mismatch') @@ -1875,6 +1940,14 @@ def _expand_trust_store(self, cert: 'SSHX509Certificate', except (OSError, KeyImportError): pass + @classmethod + def construct(cls, packet: SSHPacket, algorithm: bytes, + key_handler: Optional[Type[SSHKey]], + comment: _Comment) -> 'SSHX509Certificate': + """Construct an SSH X.509 certificate from packetized data""" + + raise RuntimeError # pragma: no cover + @classmethod def generate(cls, signing_key: 'SSHKey', key: 'SSHKey', subject: str, issuer: Optional[str], serial: Optional[int], @@ -1916,8 +1989,8 @@ def validate_chain(self, trust_chain: Sequence['SSHX509Certificate'], host_principal: str = '') -> None: """Validate an X.509 certificate chain""" - trust_store = set(c for c in trust_chain if c.subject != c.issuer) | \ - set(c for c in trusted_certs) + trust_store = {c for c in trust_chain if c.subject != c.issuer} | \ + set(trusted_certs) if trusted_cert_paths: self._expand_trust_store(self, trusted_cert_paths, trust_store) @@ -2031,8 +2104,10 @@ def __init__(self, algorithm: bytes, sig_algorithm: bytes, sig_algorithms: Sequence[bytes], host_key_algorithms: Sequence[bytes], public_data: bytes, comment: _Comment, - cert: SSHCertificate = None, filename: bytes = None, - use_executor: bool = False): + cert: Optional[SSHCertificate] = None, + filename: Optional[bytes] = None, + use_executor: bool = False, + use_webauthn: bool = False): self.key_algorithm = algorithm self.key_public_data = public_data @@ -2041,6 +2116,7 @@ def __init__(self, algorithm: bytes, sig_algorithm: bytes, self._filename = filename self.use_executor = use_executor + self.use_webauthn = use_webauthn if cert: if cert.key.public_data != self.key_public_data: @@ -2073,6 +2149,18 @@ def get_key_type(self) -> str: return self._key_type + @property + def has_cert(self) -> bool: + """ Return if this key pair has an associated cert""" + + return bool(self._cert) + + @property + def has_x509_chain(self) -> bool: + """ Return if this key pair has an associated X.509 cert chain""" + + return self._cert.is_x509_chain if self._cert else False + def get_algorithm(self) -> str: """Return the algorithm associated with this key pair""" @@ -2164,11 +2252,16 @@ def set_certificate(self, cert: SSHCertificate) -> None: def set_sig_algorithm(self, sig_algorithm: bytes) -> None: """Set the signature algorithm to use when signing data""" + try: + sig_algorithm = _certificate_sig_alg_map[sig_algorithm] + except KeyError: + pass + self.sig_algorithm = sig_algorithm - if not self._cert: + if not self.has_cert: self.algorithm = sig_algorithm - elif self._cert.is_x509_chain: + elif self.has_x509_chain: self.algorithm = sig_algorithm cert = cast('SSHX509CertificateChain', self._cert) @@ -2177,6 +2270,9 @@ def set_sig_algorithm(self, sig_algorithm: bytes) -> None: def sign(self, data: bytes) -> bytes: """Sign a block of data with this private key""" + # pylint: disable=no-self-use + raise RuntimeError # pragma: no cover + class SSHLocalKeyPair(SSHKeyPair): """Class which holds a local asymmetric key pair @@ -2189,8 +2285,9 @@ class SSHLocalKeyPair(SSHKeyPair): _key_type = 'local' - def __init__(self, key: SSHKey, pubkey: SSHKey = None, - cert: SSHCertificate = None): + def __init__(self, key: SSHKey, pubkey: Optional[SSHKey], + cert: Optional[SSHCertificate], + enc_key: Optional[_EncryptedKey]): if pubkey and pubkey.public_data != key.public_data: raise ValueError('Public key mismatch') @@ -2204,10 +2301,12 @@ def __init__(self, key: SSHKey, pubkey: SSHKey = None, comment = None super().__init__(key.algorithm, key.algorithm, key.sig_algorithms, - key.sig_algorithms, key.public_data, comment, cert, - key.get_filename(), key.use_executor) + key.sig_algorithms, key.public_data, comment, + cert, key.get_filename(), key.use_executor or + bool(enc_key), key.use_webauthn) self._key = key + self._enc_key = enc_key def get_agent_private_key(self) -> bytes: """Return binary encoding of keypair for upload to SSH agent""" @@ -2223,6 +2322,12 @@ def get_agent_private_key(self) -> bytes: def sign(self, data: bytes) -> bytes: """Sign a block of data with this private key""" + if self._enc_key: + self._key = self._enc_key.decrypt() + self._enc_key = None + + self.use_executor = self._key.use_executor + return self._key.sign(data, self.sig_algorithm) @@ -2318,10 +2423,10 @@ def _match_block(data: bytes, start: int, header: bytes, """Match a block of data wrapped in a header/footer""" match = re.compile(b'^' + header[:5] + b'END' + header[10:] + - rb'[ \t\r\f\v]*$', re.M).search(data, start) + rb'[ \t\n\r\f\v]*$', re.M).search(data, start) if not match: - raise KeyImportError('Missing %s footer' % fmt) + raise KeyImportError(f'Missing {fmt} footer') return data[start:match.start()], match.end() @@ -2371,18 +2476,24 @@ def _match_next(data: bytes, keytype: bytes, public: bool = False) -> \ return None, (), len(data) -def _decode_pkcs1_private(pem_name: bytes, key_data: object) -> SSHKey: +def _decode_pkcs1_private( + pem_name: bytes, key_data: object, + unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PKCS#1 format private key""" handler = _pem_map.get(pem_name) if handler is None: - raise KeyImportError('Unknown PEM key type: %s' % + raise KeyImportError('Unknown PEM key type: ' + pem_name.decode('ascii')) key_params = handler.decode_pkcs1_private(key_data) if key_params is None: - raise KeyImportError('Invalid %s private key' % - pem_name.decode('ascii')) + raise KeyImportError( + f'Invalid {pem_name.decode("ascii")} private key') + + if pem_name == b'RSA': + key_params = cast(Tuple, key_params) + \ + (unsafe_skip_rsa_key_validation,) return handler.make_private(key_params) @@ -2392,18 +2503,19 @@ def _decode_pkcs1_public(pem_name: bytes, key_data: object) -> SSHKey: handler = _pem_map.get(pem_name) if handler is None: - raise KeyImportError('Unknown PEM key type: %s' % + raise KeyImportError('Unknown PEM key type: ' + pem_name.decode('ascii')) key_params = handler.decode_pkcs1_public(key_data) if key_params is None: - raise KeyImportError('Invalid %s public key' % - pem_name.decode('ascii')) + raise KeyImportError(f'Invalid {pem_name.decode("ascii")} public key') return handler.make_public(key_params) -def _decode_pkcs8_private(key_data: object) -> SSHKey: +def _decode_pkcs8_private( + key_data: object, + unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PKCS#8 format private key""" if (isinstance(key_data, tuple) and len(key_data) >= 3 and @@ -2420,9 +2532,13 @@ def _decode_pkcs8_private(key_data: object) -> SSHKey: key_params = handler.decode_pkcs8_private(alg_params, key_data[2]) if key_params is None: - raise KeyImportError('Invalid %s private key' % - handler.pem_name.decode('ascii') - if handler.pem_name else 'PKCS#8') + key_type = handler.pem_name.decode('ascii') if \ + handler.pem_name else 'PKCS#8' + raise KeyImportError(f'Invalid {key_type} private key') + + if alg == ObjectIdentifier('1.2.840.113549.1.1.1'): + key_params = cast(Tuple, key_params) + \ + (unsafe_skip_rsa_key_validation,) return handler.make_private(key_params) else: @@ -2446,17 +2562,18 @@ def _decode_pkcs8_public(key_data: object) -> SSHKey: key_params = handler.decode_pkcs8_public(alg_params, key_data[1].value) if key_params is None: - raise KeyImportError('Invalid %s public key' % - handler.pem_name.decode('ascii') - if handler.pem_name else 'PKCS#8') + key_type = handler.pem_name.decode('ascii') if \ + handler.pem_name else 'PKCS#8' + raise KeyImportError(f'Invalid {key_type} public key') return handler.make_public(key_params) else: raise KeyImportError('Invalid PKCS#8 public key') -def _decode_openssh_private(data: bytes, - passphrase: Optional[BytesOrStr]) -> SSHKey: +def _decode_openssh_private( + data: bytes, passphrase: Optional[BytesOrStr], + unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode an OpenSSH format private key""" try: @@ -2483,15 +2600,14 @@ def _decode_openssh_private(data: bytes, 'encrypted private keys') try: - key_size, iv_size, block_size, _, _, _ = \ + key_size, iv_size, _, _, _, _ = \ get_encryption_params(cipher_name) except KeyError: - raise KeyEncryptionError('Unknown cipher: %s' % + raise KeyEncryptionError('Unknown cipher: ' + cipher_name.decode('ascii')) from None if kdf != b'bcrypt': - raise KeyEncryptionError('Unknown kdf: %s' % - kdf.decode('ascii')) + raise KeyEncryptionError('Unknown kdf: ' + kdf.decode('ascii')) if not _bcrypt_available: # pragma: no cover raise KeyEncryptionError('OpenSSH private key encryption ' @@ -2521,9 +2637,6 @@ def _decode_openssh_private(data: bytes, raise KeyEncryptionError('Incorrect passphrase') key_data = decrypted_key - block_size = max(block_size, 8) - else: - block_size = 8 packet = SSHPacket(key_data) @@ -2544,9 +2657,13 @@ def _decode_openssh_private(data: bytes, comment = packet.get_string() pad = packet.get_remaining_payload() - if len(pad) >= block_size or pad != bytes(range(1, len(pad) + 1)): + if len(pad) >= 256 or pad != bytes(range(1, len(pad) + 1)): raise KeyImportError('Invalid OpenSSH private key') + if alg == b'ssh-rsa': + key_params = cast(Tuple, key_params) + \ + (unsafe_skip_rsa_key_validation,) + key = handler.make_private(key_params) key.set_comment(comment) return key @@ -2578,8 +2695,9 @@ def _decode_openssh_public(data: bytes) -> SSHKey: raise KeyImportError('Invalid OpenSSH private key') from None -def _decode_der_private(key_data: object, - passphrase: Optional[BytesOrStr]) -> SSHKey: +def _decode_der_private( + key_data: object, passphrase: Optional[BytesOrStr], + unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a DER format private key""" # First, if there's a passphrase, try to decrypt PKCS#8 @@ -2592,7 +2710,7 @@ def _decode_der_private(key_data: object, # Then, try to decode PKCS#8 try: - return _decode_pkcs8_private(key_data) + return _decode_pkcs8_private(key_data, unsafe_skip_rsa_key_validation) except KeyImportError: # PKCS#8 failed - try PKCS#1 instead pass @@ -2600,7 +2718,8 @@ def _decode_der_private(key_data: object, # If that fails, try each of the possible PKCS#1 encodings for pem_name in _pem_map: try: - return _decode_pkcs1_private(pem_name, key_data) + return _decode_pkcs1_private(pem_name, key_data, + unsafe_skip_rsa_key_validation) except KeyImportError: # Try the next PKCS#1 encoding pass @@ -2636,13 +2755,15 @@ def _decode_der_certificate(data: bytes, return SSHX509Certificate.construct_from_der(data, comment) -def _decode_pem_private(pem_name: bytes, headers: Mapping[bytes, bytes], - data: bytes, passphrase: Optional[BytesOrStr]) -> \ - SSHKey: +def _decode_pem_private( + pem_name: bytes, headers: Mapping[bytes, bytes], + data: bytes, passphrase: Optional[BytesOrStr], + unsafe_skip_rsa_key_validation: Optional[bool]) -> SSHKey: """Decode a PEM format private key""" if pem_name == b'OPENSSH': - return _decode_openssh_private(data, passphrase) + return _decode_openssh_private(data, passphrase, + unsafe_skip_rsa_key_validation) if headers.get(b'Proc-Type') == b'4,ENCRYPTED': if passphrase is None: @@ -2684,9 +2805,10 @@ def _decode_pem_private(pem_name: bytes, headers: Mapping[bytes, bytes], 'private key') from None if pem_name: - return _decode_pkcs1_private(pem_name, key_data) + return _decode_pkcs1_private(pem_name, key_data, + unsafe_skip_rsa_key_validation) else: - return _decode_pkcs8_private(key_data) + return _decode_pkcs8_private(key_data, unsafe_skip_rsa_key_validation) def _decode_pem_public(pem_name: bytes, data: bytes) -> SSHKey: @@ -2719,8 +2841,10 @@ def _decode_pem_certificate(pem_name: bytes, data: bytes) -> SSHCertificate: return SSHX509Certificate.construct_from_der(data) -def _decode_private(data: bytes, passphrase: Optional[BytesOrStr]) -> \ - Tuple[Optional[SSHKey], Optional[int]]: +def _decode_private( + data: bytes, passphrase: Optional[BytesOrStr], + unsafe_skip_rsa_key_validation: Optional[bool]) -> \ + Tuple[Optional[SSHKey], Optional[int]]: """Decode a private key""" fmt, key_info, end = _match_next(data, b'PRIVATE KEY') @@ -2728,10 +2852,12 @@ def _decode_private(data: bytes, passphrase: Optional[BytesOrStr]) -> \ key: Optional[SSHKey] if fmt == 'der': - key = _decode_der_private(key_info[0], passphrase) + key = _decode_der_private(key_info[0], passphrase, + unsafe_skip_rsa_key_validation) elif fmt == 'pem': pem_name, headers, data = key_info - key = _decode_pem_private(pem_name, headers, data, passphrase) + key = _decode_pem_private(pem_name, headers, data, passphrase, + unsafe_skip_rsa_key_validation) else: key = None @@ -2768,7 +2894,7 @@ def _decode_public(data: bytes) -> Tuple[Optional[SSHKey], Optional[int]]: if fmt == 'pem' and key_info[0] == b'OPENSSH': key = _decode_openssh_public(key_info[2]) else: - key, _ = _decode_private(data, None) + key, _ = _decode_private(data, None, False) if key: key = key.convert_to_public() @@ -2805,14 +2931,16 @@ def _decode_certificate(data: bytes) -> \ return cert, end -def _decode_private_list(data: bytes, passphrase: Optional[BytesOrStr]) -> \ - Sequence[SSHKey]: +def _decode_private_list( + data: bytes, passphrase: Optional[BytesOrStr], + unsafe_skip_rsa_key_validation: Optional[bool]) -> Sequence[SSHKey]: """Decode a private key list""" keys: List[SSHKey] = [] while data: - key, end = _decode_private(data, passphrase) + key, end = _decode_private(data, passphrase, + unsafe_skip_rsa_key_validation) if key: keys.append(key) @@ -2862,7 +2990,8 @@ def register_sk_alg(sk_alg: int, handler: Type[SSHKey], *args: object) -> None: def register_public_key_alg(algorithm: bytes, handler: Type[SSHKey], default: bool, - sig_algorithms: Sequence[bytes] = None) -> None: + sig_algorithms: Optional[Sequence[bytes]] = \ + None) -> None: """Register a new public key algorithm""" if not sig_algorithms: @@ -2896,6 +3025,8 @@ def register_certificate_alg(version: int, algorithm: bytes, _certificate_alg_map[cert_algorithm] = (key_handler, cert_handler) + _certificate_sig_alg_map[cert_algorithm] = algorithm + _certificate_version_map[algorithm, version] = \ (cert_algorithm, cert_handler) @@ -2964,7 +3095,7 @@ def decode_ssh_public_key(data: bytes) -> SSHKey: key.algorithm = alg return key else: - raise KeyImportError('Unknown key algorithm: %s' % + raise KeyImportError('Unknown key algorithm: ' + alg.decode('ascii', errors='replace')) except PacketDecodeError: raise KeyImportError('Invalid public key') from None @@ -2982,7 +3113,7 @@ def decode_ssh_certificate(data: bytes, if cert_handler: return cert_handler.construct(packet, alg, key_handler, comment) else: - raise KeyImportError('Unknown certificate algorithm: %s' % + raise KeyImportError('Unknown certificate algorithm: ' + alg.decode('ascii', errors='replace')) except (PacketDecodeError, ValueError): raise KeyImportError('Invalid OpenSSH certificate') from None @@ -3083,13 +3214,14 @@ def generate_private_key(alg_name: str, comment: _Comment = None, except (TypeError, ValueError) as exc: raise KeyGenerationError(str(exc)) from None else: - raise KeyGenerationError('Unknown algorithm: %s' % alg_name) + raise KeyGenerationError('Unknown algorithm: ' + alg_name) key.set_comment(comment) return key -def import_private_key(data: BytesOrStr, - passphrase: Optional[BytesOrStr] = None) -> SSHKey: +def import_private_key( + data: BytesOrStr, passphrase: Optional[BytesOrStr] = None, + unsafe_skip_rsa_key_validation: Optional[bool] = None) -> SSHKey: """Import a private key This function imports a private key encoded in PKCS#1 or PKCS#8 DER @@ -3100,8 +3232,13 @@ def import_private_key(data: BytesOrStr, The data to import. :param passphrase: (optional) The passphrase to use to decrypt the key. + :param unsafe_skip_rsa_key_validation: (optional) + Whether or not to skip key validation when loading RSA private + keys, defaulting to performing these checks unless changed by + calling :func:`set_default_skip_rsa_key_validation`. :type data: `bytes` or ASCII `str` :type passphrase: `str` or `bytes` + :type unsafe_skip_rsa_key_validation: bool :returns: An :class:`SSHKey` private key @@ -3113,7 +3250,7 @@ def import_private_key(data: BytesOrStr, except UnicodeEncodeError: raise KeyImportError('Invalid encoding for key') from None - key, _ = _decode_private(data, passphrase) + key, _ = _decode_private(data, passphrase, unsafe_skip_rsa_key_validation) if key: return key @@ -3121,19 +3258,6 @@ def import_private_key(data: BytesOrStr, raise KeyImportError('Invalid private key') -def import_private_key_and_certs(data: bytes, - passphrase: Optional[BytesOrStr] = None) -> \ - Tuple[SSHKey, Optional[SSHX509CertificateChain]]: - """Import a private key and optional certificate chain""" - - key, end = _decode_private(data, passphrase) - - if key: - return key, import_certificate_chain(data[end:]) - else: - raise KeyImportError('Invalid private key') - - def import_public_key(data: BytesOrStr) -> SSHKey: """Import a public key @@ -3222,8 +3346,9 @@ def import_certificate_subject(data: str) -> str: raise KeyImportError('Invalid certificate subject') -def read_private_key(filename: FilePath, - passphrase: Optional[BytesOrStr] = None) -> SSHKey: +def read_private_key( + filename: FilePath, passphrase: Optional[BytesOrStr] = None, + unsafe_skip_rsa_key_validation: Optional[bool] = None) -> SSHKey: """Read a private key from a file This function reads a private key from a file. See the function @@ -3234,32 +3359,26 @@ def read_private_key(filename: FilePath, The file to read the key from. :param passphrase: (optional) The passphrase to use to decrypt the key. + :param unsafe_skip_rsa_key_validation: (optional) + Whether or not to skip key validation when loading RSA private + keys, defaulting to performing these checks unless changed by + calling :func:`set_default_skip_rsa_key_validation`. :type filename: :class:`PurePath ` or `str` :type passphrase: `str` or `bytes` + :type unsafe_skip_rsa_key_validation: bool :returns: An :class:`SSHKey` private key """ - key = import_private_key(read_file(filename), passphrase) + key = import_private_key(read_file(filename), passphrase, + unsafe_skip_rsa_key_validation) key.set_filename(filename) return key -def read_private_key_and_certs(filename: FilePath, - passphrase: Optional[BytesOrStr] = None) -> \ - Tuple[SSHKey, Optional[SSHX509CertificateChain]]: - """Read a private key and optional certificate chain from a file""" - - key, cert = import_private_key_and_certs(read_file(filename), passphrase) - - key.set_filename(filename) - - return key, cert - - def read_public_key(filename: FilePath) -> SSHKey: """Read a public key from a file @@ -3300,9 +3419,10 @@ def read_certificate(filename: FilePath) -> SSHCertificate: return import_certificate(read_file(filename)) -def read_private_key_list(filename: FilePath, - passphrase: Optional[BytesOrStr] = None) -> \ - Sequence[SSHKey]: +def read_private_key_list( + filename: FilePath, passphrase: Optional[BytesOrStr] = None, + unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \ + Sequence[SSHKey]: """Read a list of private keys from a file This function reads a list of private keys from a file. See the @@ -3314,14 +3434,20 @@ def read_private_key_list(filename: FilePath, The file to read the keys from. :param passphrase: (optional) The passphrase to use to decrypt the keys. + :param unsafe_skip_rsa_key_validation: (optional) + Whether or not to skip key validation when loading RSA private + keys, defaulting to performing these checks unless changed by + calling :func:`set_default_skip_rsa_key_validation`. :type filename: :class:`PurePath ` or `str` :type passphrase: `str` or `bytes` + :type unsafe_skip_rsa_key_validation: bool :returns: A list of :class:`SSHKey` private keys """ - keys = _decode_private_list(read_file(filename), passphrase) + keys = _decode_private_list(read_file(filename), passphrase, + unsafe_skip_rsa_key_validation) for key in keys: key.set_filename(filename) @@ -3370,10 +3496,13 @@ def read_certificate_list(filename: FilePath) -> Sequence[SSHCertificate]: return _decode_certificate_list(read_file(filename)) -def load_keypairs(keylist: KeyPairListArg, - passphrase: Optional[BytesOrStr] = None, - certlist: CertListArg = (), skip_public: bool = False, - ignore_encrypted: bool = False) -> Sequence[SSHKeyPair]: +def load_keypairs( + keylist: KeyPairListArg, passphrase: Optional[BytesOrStr] = None, + certlist: CertListArg = (), skip_public: bool = False, + ignore_encrypted: bool = False, + unsafe_skip_rsa_key_validation: Optional[bool] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> \ + Sequence[SSHKeyPair]: """Load SSH private keys and optional matching certificates This function loads a list of SSH keys and optional matching @@ -3385,7 +3514,8 @@ def load_keypairs(keylist: KeyPairListArg, :param keylist: The list of private keys and certificates to load. :param passphrase: (optional) - The passphrase to use to decrypt private keys. + The passphrase to use to decrypt the keys, or a `callable` which + takes a filename and returns the passphrase to decrypt it. :param certlist: (optional) A list of certificates to attempt to pair with the provided list of private keys. @@ -3393,26 +3523,52 @@ def load_keypairs(keylist: KeyPairListArg, An internal parameter used to skip public keys and certificates when IdentitiesOnly and IdentityFile are used to specify a mixture of private and public keys. + :param unsafe_skip_rsa_key_validation: (optional) + Whether or not to skip key validation when loading RSA private + keys, defaulting to performing these checks unless changed by + calling :func:`set_default_skip_rsa_key_validation`. :type keylist: *see* :ref:`SpecifyingPrivateKeys` :type passphrase: `str` or `bytes` :type certlist: *see* :ref:`SpecifyingCertificates` :type skip_public: `bool` + :type unsafe_skip_rsa_key_validation: bool :returns: A list of :class:`SSHKeyPair` objects """ keys_to_load: Sequence[_KeyPairArg] + key_data: Optional[bytes] + key: Union['SSHKey', 'SSHKeyPair'] result: List[SSHKeyPair] = [] certlist = load_certificates(certlist) certdict = {cert.key.public_data: cert for cert in certlist} if isinstance(keylist, (PurePath, str)): - try: - priv_keys = read_private_key_list(keylist, passphrase) - keys_to_load = [keylist] if len(priv_keys) <= 1 else priv_keys - except KeyImportError: + data = read_file(keylist) + key_data_list: List[bytes] = [] + + while data: + fmt, _, end = _match_next(data, b'PRIVATE KEY') + if fmt: + key_data_list.append(data[:end]) + + data = data[end:] + + if len(key_data_list) > 1: + resolved_passphrase = _resolve_passphrase(passphrase, + str(keylist), loop) + + keys_to_load = [] + + for key_data in key_data_list: + key = import_private_key(key_data, resolved_passphrase, + unsafe_skip_rsa_key_validation) + key.set_filename(keylist) + + keys_to_load.append(key) + else: keys_to_load = [keylist] elif isinstance(keylist, (tuple, bytes, SSHKey, SSHKeyPair)): keys_to_load = [cast(_KeyPairArg, keylist)] @@ -3421,48 +3577,37 @@ def load_keypairs(keylist: KeyPairListArg, for key_to_load in keys_to_load: allow_certs = False - key_prefix = None - saved_exc = None + key_data = None + key_prefix = '' pubkey_or_certs = None - pubkey_to_load: Optional[_KeyArg] = None certs_to_load: Optional[_CertArg] = None - key: Union['SSHKey', 'SSHKeyPair'] + pubkey_to_load: Optional[_KeyArg] = None + saved_exc = None + enc_key: Optional[_EncryptedKey] = None if isinstance(key_to_load, (PurePath, str, bytes)): allow_certs = True elif isinstance(key_to_load, tuple): key_to_load, pubkey_or_certs = key_to_load - try: - if isinstance(key_to_load, (PurePath, str)): - key_prefix = str(key_to_load) + if isinstance(key_to_load, (PurePath, str)): + key_prefix = str(key_to_load) + key_data = read_file(key_to_load) + elif isinstance(key_to_load, bytes): + key_data = key_to_load - if allow_certs: - key, certs_to_load = \ - read_private_key_and_certs(key_to_load, passphrase) + certs: Optional[Sequence[SSHCertificate]] - if not certs_to_load: - certs_to_load = key_prefix + '-cert.pub' - else: - key = read_private_key(key_to_load, passphrase) + if allow_certs: + assert key_data is not None - pubkey_to_load = key_prefix + '.pub' - elif isinstance(key_to_load, bytes): - if allow_certs: - key, certs_to_load = \ - import_private_key_and_certs(key_to_load, passphrase) - else: - key = import_private_key(key_to_load, passphrase) - else: - key = key_to_load - except KeyImportError as exc: - if skip_public or \ - (ignore_encrypted and str(exc).startswith('Passphrase')): - continue + _, _, end = _match_next(key_data, b'PRIVATE KEY') - raise + certs_to_load = import_certificate_chain(key_data[end:]) + key_data = key_data[:end] - certs: Optional[Sequence[SSHCertificate]] + if not certs_to_load: + certs_to_load = key_prefix + '-cert.pub' if pubkey_or_certs: try: @@ -3476,7 +3621,7 @@ def load_keypairs(keylist: KeyPairListArg, elif certs_to_load: try: certs = load_certificates(certs_to_load) - except (OSError, KeyImportError): + except (OSError, KeyImportError) as exc: certs = None else: certs = None @@ -3491,16 +3636,58 @@ def load_keypairs(keylist: KeyPairListArg, pubkey = import_public_key(pubkey_to_load) else: pubkey = pubkey_to_load + + saved_exc = None except (OSError, KeyImportError): pubkey = None - else: + elif key_prefix: + try: + pubkey = read_public_key(key_prefix + '.pub') saved_exc = None + except (OSError, KeyImportError): + try: + pubkey = read_public_key(key_prefix) + saved_exc = None + except (OSError, KeyImportError): + pubkey = None else: pubkey = None if saved_exc: raise saved_exc # pylint: disable=raising-bad-type + if key_data is not None: + try: + unencrypted_key = import_private_key( + key_data, None, unsafe_skip_rsa_key_validation) + unencrypted_key.set_filename(key_prefix) + except KeyImportError: + unencrypted_key = None + + if unencrypted_key: + key = unencrypted_key + elif callable(passphrase) and key_prefix and (certs or pubkey): + enc_key = _EncryptedKey(key_data, key_prefix, passphrase, loop, + unsafe_skip_rsa_key_validation) + + key = certs[0].key if certs else pubkey + else: + try: + resolved_passphrase = _resolve_passphrase(passphrase, + key_prefix, loop) + + key = import_private_key(key_data, passphrase, + unsafe_skip_rsa_key_validation) + key.set_filename(key_prefix) + except KeyImportError as exc: + if skip_public or (ignore_encrypted and + str(exc).startswith('Passphrase')): + continue + + raise + else: + key = cast(Union[SSHKey, SSHKeyPair], key_to_load) + if not certs: if isinstance(key, SSHKeyPair): pubdata = key.key_public_data @@ -3523,9 +3710,9 @@ def load_keypairs(keylist: KeyPairListArg, result.append(key) else: if cert: - result.append(SSHLocalKeyPair(key, pubkey, cert)) + result.append(SSHLocalKeyPair(key, pubkey, cert, enc_key)) - result.append(SSHLocalKeyPair(key, pubkey)) + result.append(SSHLocalKeyPair(key, pubkey, None, enc_key)) return result @@ -3696,7 +3883,8 @@ def load_default_identities() -> Sequence[bytes]: return result -def load_resident_keys(pin: str, *, application: str = 'ssh:', user: str = None, +def load_resident_keys(pin: str, *, application: str = 'ssh:', + user: Optional[str] = None, touch_required: bool = True) -> Sequence[SSHKey]: """Load keys resident on attached FIDO2 security keys @@ -3722,7 +3910,6 @@ def load_resident_keys(pin: str, *, application: str = 'ssh:', user: str = None, """ - application = application.encode('utf-8') flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 reserved = b'' diff --git a/asyncssh/rsa.py b/asyncssh/rsa.py index 09edc59..8846a5a 100644 --- a/asyncssh/rsa.py +++ b/asyncssh/rsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2022 by Ron Frederick and others. +# Copyright (c) 2013-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -43,22 +43,65 @@ _PrivateKeyArgs = Tuple[int, int, int, int, int, int, int, int] +_PrivateKeyConstructArgs = Tuple[int, int, int, int, int, int, int, int, bool] _PublicKeyArgs = Tuple[int, int] +_default_skip_rsa_key_validation = False + + +def set_default_skip_rsa_key_validation(skip_validation: bool) -> None: + """Set whether to disable RSA key validation in OpenSSL + + OpenSSL 3.x does additional validation when loading RSA keys + as an added security measure. However, the result is that + loading a key can take significantly longer than it did before. + + If all your RSA keys are coming from a trusted source, you can + call this function with a value of `True` to default to skipping + these checks on RSA keys, reducing the cost back down to what it + was in earlier releases. + + This can also be set on a case by case basis by using the new + `unsafe_skip_rsa_key_validation` argument on the functions used + to load keys. This will only affect loading keys of type RSA. + + .. note:: The extra cost only applies to loading existing keys, and + not to generating new keys. Also, in cases where a key is + used repeatedly, it can be loaded once into an `SSHKey` + object and reused without having to pay the cost each time. + So, this call should not be needed in most applications. + + If an application does need this, it is strongly + recommended that the `unsafe_skip_rsa_key_validation` + argument be used rather than using this function to + change the default behavior for all load operations. + + """ + + # pylint: disable=global-statement + + global _default_skip_rsa_key_validation + + _default_skip_rsa_key_validation = skip_validation + + class RSAKey(SSHKey): """Handler for RSA public key encryption""" _key: Union[RSAPrivateKey, RSAPublicKey] algorithm = b'ssh-rsa' - default_hash_name = 'sha256' + default_x509_hash = 'sha256' pem_name = b'RSA' pkcs8_oid = ObjectIdentifier('1.2.840.113549.1.1.1') sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa-sha224@ssh.com', b'ssh-rsa-sha256@ssh.com', b'ssh-rsa-sha384@ssh.com', b'ssh-rsa-sha512@ssh.com', b'ssh-rsa') + cert_sig_algorithms = (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa') + cert_algorithms = tuple(alg + b'-cert-v01@openssh.com' + for alg in cert_sig_algorithms) x509_sig_algorithms = (b'rsa2048-sha256', b'ssh-rsa') x509_algorithms = tuple(b'x509v3-' + alg for alg in x509_sig_algorithms) all_sig_algorithms = set(x509_sig_algorithms + sig_algorithms) @@ -91,9 +134,14 @@ def generate(cls, algorithm: bytes, *, # type: ignore def make_private(cls, key_params: object) -> SSHKey: """Construct an RSA private key""" - n, e, d, p, q, dmp1, dmq1, iqmp = cast(_PrivateKeyArgs, key_params) + n, e, d, p, q, dmp1, dmq1, iqmp, unsafe_skip_rsa_key_validation = \ + cast(_PrivateKeyConstructArgs, key_params) + + if unsafe_skip_rsa_key_validation is None: + unsafe_skip_rsa_key_validation = _default_skip_rsa_key_validation - return cls(RSAPrivateKey.construct(n, e, d, p, q, dmp1, dmq1, iqmp)) + return cls(RSAPrivateKey.construct(n, e, d, p, q, dmp1, dmq1, iqmp, + unsafe_skip_rsa_key_validation)) @classmethod def make_public(cls, key_params: object) -> SSHKey: @@ -265,7 +313,7 @@ def decrypt(self, data: bytes, algorithm: bytes) -> Optional[bytes]: register_public_key_alg(b'ssh-rsa', RSAKey, True) -for _alg in (b'rsa-sha2-256', b'rsa-sha2-512', b'ssh-rsa'): +for _alg in RSAKey.cert_sig_algorithms: register_certificate_alg(1, _alg, _alg + b'-cert-v01@openssh.com', RSAKey, SSHOpenSSHCertificateV01, True) diff --git a/asyncssh/saslprep.py b/asyncssh/saslprep.py index fe6a2cb..9b73661 100644 --- a/asyncssh/saslprep.py +++ b/asyncssh/saslprep.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,9 +31,11 @@ # pylint: disable=deprecated-module import stringprep # pylint: enable=deprecated-module -from typing import Callable, Sequence import unicodedata +from typing import Callable, Optional, Sequence +from typing_extensions import Literal + class SASLPrepError(ValueError): """Invalid data provided to saslprep""" @@ -60,15 +62,17 @@ def _check_bidi(s: str) -> None: raise SASLPrepError('RandALCat character not at both start and end') -def _stringprep(s: str, check_unassigned: bool, mapping: Callable[[str], str], - normalization: str, prohibited: Sequence[Callable[[str], bool]], +def _stringprep(s: str, check_unassigned: bool, + mapping: Optional[Callable[[str], str]], + normalization: Literal['NFC', 'NFD', 'NFKC', 'NFKD'], + prohibited: Sequence[Callable[[str], bool]], bidi: bool) -> str: """Implement a stringprep profile as defined in RFC 3454""" if check_unassigned: # pragma: no branch for c in s: if stringprep.in_table_a1(c): - raise SASLPrepError('Unassigned character: %r' % c) + raise SASLPrepError(f'Unassigned character: {c!r}') if mapping: # pragma: no branch s = mapping(s) @@ -80,7 +84,7 @@ def _stringprep(s: str, check_unassigned: bool, mapping: Callable[[str], str], for c in s: for lookup in prohibited: if lookup(c): - raise SASLPrepError('Prohibited character: %r' % c) + raise SASLPrepError(f'Prohibited character: {c!r}') if bidi: # pragma: no branch _check_bidi(s) diff --git a/asyncssh/scp.py b/asyncssh/scp.py index 4504701..f345d17 100644 --- a/asyncssh/scp.py +++ b/asyncssh/scp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,25 +24,25 @@ import argparse import asyncio -import posixpath +import inspect from pathlib import PurePath +import posixpath import shlex -import stat import string import sys from types import TracebackType -from typing import TYPE_CHECKING, List, NoReturn, Optional +from typing import TYPE_CHECKING, AsyncIterator, List, NoReturn, Optional from typing import Sequence, Tuple, Type, Union, cast -from typing_extensions import Protocol +from typing_extensions import Protocol, Self from .constants import DEFAULT_LANG +from .constants import FILEXFER_TYPE_REGULAR, FILEXFER_TYPE_DIRECTORY from .logging import SSHLogger from .misc import BytesOrStr, FilePath, HostPort, MaybeAwait from .misc import async_context_manager, plural -from .sftp import SFTPAttrs, SFTPServer, SFTPServerFS, SFTPFileProtocol +from .sftp import SFTPAttrs, SFTPGlob, SFTPName, SFTPServer, SFTPServerFS from .sftp import SFTPError, SFTPFailure, SFTPBadMessage, SFTPConnectionLost -from .sftp import SFTPErrorHandler, SFTPProgressHandler -from .sftp import SFTP_BLOCK_SIZE, local_fs, match_glob +from .sftp import SFTPErrorHandler, SFTPProgressHandler, local_fs if TYPE_CHECKING: @@ -57,6 +57,30 @@ _SCPConnPath = Union[Tuple[_SCPConn, _SCPPath], _SCPConn, _SCPPath] +_SCP_BLOCK_SIZE = 256*1024 # 256 KiB + + +class _SCPFileProtocol(Protocol): + """Protocol for accessing a file during an SCP copy""" + + async def __aenter__(self) -> Self: + """Allow _SCPFileProtocol to be used as an async context manager""" + + async def __aexit__(self, _exc_type: Optional[Type[BaseException]], + _exc_value: Optional[BaseException], + _traceback: Optional[TracebackType]) -> bool: + """Wait for file close when used as an async context manager""" + + async def read(self, size: int, offset: int) -> bytes: + """Read data from the local file""" + + async def write(self, data: bytes, offset: int) -> int: + """Write data to the local file""" + + async def close(self) -> None: + """Close the local file""" + + class _SCPFSProtocol(Protocol): """Protocol for accessing a filesystem during an SCP copy""" @@ -76,14 +100,14 @@ async def exists(self, path: bytes) -> bool: async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" - async def listdir(self, path: bytes) -> Sequence[bytes]: - """List the contents of a directory""" + def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: + """Read the names and attributes of files in a directory""" async def mkdir(self, path: bytes) -> None: """Create a directory""" @async_context_manager - async def open(self, path: bytes, mode: str) -> SFTPFileProtocol: + async def open(self, path: bytes, mode: str) -> _SCPFileProtocol: """Open a file""" @@ -240,7 +264,7 @@ def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', self._logger = reader.logger.get_child('sftp') - async def __aenter__(self) -> '_SCPHandler': # pragma: no cover + async def __aenter__(self) -> Self: # pragma: no cover """Allow _SCPHandler to be used as an async context manager""" return self @@ -383,22 +407,25 @@ def handle_error(self, exc: Exception) -> None: self.logger.debug1('Handling SCP error: %s', str(exc)) - if getattr(exc, 'fatal', False) or self._error_handler is None: - raise exc from None - elif self._error_handler: + if self._error_handler and not getattr(exc, 'fatal', False): self._error_handler(exc) + elif not self._server: + raise exc - async def close(self) -> None: + async def close(self, cancelled: bool = False) -> None: """Close an SCP session""" self.logger.info('Stopping remote SCP') - if self._server: - cast('SSHServerChannel', self._writer.channel).exit(0) + if cancelled: + self._writer.channel.abort() else: - self._writer.close() + if self._server: + cast('SSHServerChannel', self._writer.channel).exit(0) + else: + self._writer.close() - await self._writer.channel.wait_closed() + await self._writer.wait_closed() class _SCPSource(_SCPHandler): @@ -406,7 +433,7 @@ class _SCPSource(_SCPHandler): def __init__(self, fs: _SCPFSProtocol, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', preserve: bool, recurse: bool, - block_size: int = SFTP_BLOCK_SIZE, + block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, server: bool = False): super().__init__(reader, writer, error_handler, server) @@ -423,7 +450,7 @@ async def _make_cd_request(self, action: bytes, attrs: SFTPAttrs, assert attrs.permissions is not None - args = '%04o %d ' % (attrs.permissions & 0o7777, size) + args = f'{attrs.permissions & 0o7777:04o} {size} ' await self.make_request(action, args.encode('ascii'), self._fs.basename(path)) @@ -436,7 +463,7 @@ async def _make_t_request(self, attrs: SFTPAttrs) -> None: assert attrs.mtime is not None assert attrs.atime is not None - args = '%d 0 %d 0' % (attrs.mtime, attrs.atime) + args = f'{attrs.mtime} 0 {attrs.atime} 0' await self.make_request(b'T', args.encode('ascii')) async def _send_file(self, srcpath: bytes, @@ -493,38 +520,39 @@ async def _send_file(self, srcpath: bytes, if final_exc: raise final_exc - async def _send_dir(self, srcpath: bytes, - dstpath: bytes, attrs: SFTPAttrs) -> None: + async def _send_dir(self, srcpath: bytes, dstpath: bytes, + attrs: SFTPAttrs) -> None: """Send directory over SCP""" self.logger.info(' Starting send of directory %s', srcpath) await self._make_cd_request(b'D', attrs, 0, srcpath) - for name in await self._fs.listdir(srcpath): + async for entry in self._fs.scandir(srcpath): + name = cast(bytes, entry.filename) + if name in (b'.', b'..'): continue await self._send_files(posixpath.join(srcpath, name), - posixpath.join(dstpath, name)) + posixpath.join(dstpath, name), + entry.attrs) await self.make_request(b'E') self.logger.info(' Finished send of directory %s', srcpath) - async def _send_files(self, srcpath: bytes, dstpath: bytes) -> None: + async def _send_files(self, srcpath: bytes, dstpath: bytes, + attrs: SFTPAttrs) -> None: """Send files via SCP""" try: - attrs = await self._fs.stat(srcpath) - assert attrs.permissions is not None - if self._preserve: await self._make_t_request(attrs) - if self._recurse and stat.S_ISDIR(attrs.permissions): + if self._recurse and attrs.type == FILEXFER_TYPE_DIRECTORY: await self._send_dir(srcpath, dstpath, attrs) - elif stat.S_ISREG(attrs.permissions): + elif attrs.type == FILEXFER_TYPE_REGULAR: await self._send_file(srcpath, dstpath, attrs) else: raise _scp_error(SFTPFailure, 'Not a regular file', srcpath) @@ -534,6 +562,8 @@ async def _send_files(self, srcpath: bytes, dstpath: bytes) -> None: async def run(self, srcpath: _SCPPath) -> None: """Start SCP transfer""" + cancelled = False + try: if isinstance(srcpath, PurePath): srcpath = str(srcpath) @@ -546,12 +576,16 @@ async def run(self, srcpath: _SCPPath) -> None: if exc: raise exc - for path in await match_glob(self._fs, srcpath): - await self._send_files(path, b'') + for name in await SFTPGlob(self._fs).match(srcpath): + await self._send_files(cast(bytes, name.filename), + b'', name.attrs) + except (KeyboardInterrupt, asyncio.CancelledError): + cancelled = True + raise except (OSError, SFTPError) as exc: self.handle_error(exc) finally: - await self.close() + await self.close(cancelled) class _SCPSink(_SCPHandler): @@ -559,7 +593,7 @@ class _SCPSink(_SCPHandler): def __init__(self, fs: _SCPFSProtocol, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', must_be_dir: bool, preserve: bool, - recurse: bool, block_size: int = SFTP_BLOCK_SIZE, + recurse: bool, block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, server: bool = False): super().__init__(reader, writer, error_handler, server) @@ -657,7 +691,8 @@ async def _recv_files(self, srcpath: bytes, dstpath: bytes) -> None: try: if action in b'\x01\x02': - raise _scp_error(SFTPFailure, args, fatal=action != b'\x01', + raise _scp_error(SFTPFailure, args, + fatal=action != b'\x01', suppress_send=True) elif action == b'T': if self._preserve: @@ -697,6 +732,8 @@ async def _recv_files(self, srcpath: bytes, dstpath: bytes) -> None: async def run(self, dstpath: _SCPPath) -> None: """Start SCP file receive""" + cancelled = False + try: if isinstance(dstpath, PurePath): dstpath = str(dstpath) @@ -709,10 +746,13 @@ async def run(self, dstpath: _SCPPath) -> None: dstpath)) else: await self._recv_files(b'', dstpath) + except (KeyboardInterrupt, asyncio.CancelledError): + cancelled = True + raise except (OSError, SFTPError, ValueError) as exc: self.handle_error(exc) finally: - await self.close() + await self.close(cancelled) class _SCPCopier: @@ -722,7 +762,7 @@ def __init__(self, src_reader: 'SSHReader[bytes]', src_writer: 'SSHWriter[bytes]', dst_reader: 'SSHReader[bytes]', dst_writer: 'SSHWriter[bytes]', - block_size: int = SFTP_BLOCK_SIZE, + block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None): self._source = _SCPHandler(src_reader, src_writer) @@ -742,7 +782,8 @@ def _handle_error(self, exc: Exception) -> None: """Handle an SCP error""" if isinstance(exc, BrokenPipeError): - exc = _scp_error(SFTPConnectionLost, 'Connection lost', fatal=True) + exc = _scp_error(SFTPConnectionLost, 'Connection lost', + fatal=True, suppress_send=True) self.logger.debug1('Handling SCP error: %s', str(exc)) @@ -785,7 +826,7 @@ async def _copy_file(self, path: bytes, size: int) -> None: if not data: raise _scp_error(SFTPConnectionLost, 'Connection lost', - fatal=True) + fatal=True, suppress_send=True) await self._sink.send_data(data) offset += len(data) @@ -868,18 +909,23 @@ async def _copy_files(self) -> None: async def run(self) -> None: """Start SCP remote-to-remote transfer""" + cancelled = False + try: await self._copy_files() + except (KeyboardInterrupt, asyncio.CancelledError): + cancelled = True + raise except (OSError, SFTPError) as exc: self._handle_error(exc) finally: - await self._source.close() - await self._sink.close() + await self._source.close(cancelled) + await self._sink.close(cancelled) async def scp(srcpaths: Union[_SCPConnPath, Sequence[_SCPConnPath]], dstpath: _SCPConnPath = None, *, preserve: bool = False, - recurse: bool = False, block_size: int = SFTP_BLOCK_SIZE, + recurse: bool = False, block_size: int = _SCP_BLOCK_SIZE, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None, **kwargs) -> None: """Copy files using SCP @@ -936,7 +982,7 @@ async def scp(srcpaths: Union[_SCPConnPath, Sequence[_SCPConnPath]], SFTP instead. The block_size value controls the size of read and write operations - issued to copy the files. It defaults to 16 KB. + issued to copy the files. It defaults to 256 KB. If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments passed to @@ -1044,22 +1090,48 @@ async def scp(srcpaths: Union[_SCPConnPath, Sequence[_SCPConnPath]], await dstconn.wait_closed() -def run_scp_server(sftp_server: SFTPServer, command: str, +async def _scp_handler(sftp_server: MaybeAwait[SFTPServer], + args: _SCPArgs, reader: 'SSHReader[bytes]', + writer: 'SSHWriter[bytes]') -> None: + """Run an SCP server to handle this request""" + + if inspect.isawaitable(sftp_server): + sftp_server = await sftp_server + + sftp_server: SFTPServer + + fs = SFTPServerFS(sftp_server) + + handler: Union[_SCPSource, _SCPSink] + + if args.source: + handler = _SCPSource(fs, reader, writer, args.preserve, + args.recurse, error_handler=False, server=True) + else: + handler = _SCPSink(fs, reader, writer, args.must_be_dir, + args.preserve, args.recurse, + error_handler=False, server=True) + + try: + await handler.run(args.path) + finally: + result = sftp_server.exit() + + if inspect.isawaitable(result): + await result + + +def run_scp_server(sftp_server: MaybeAwait[SFTPServer], command: str, stdin: 'SSHReader[bytes]', stdout: 'SSHWriter[bytes]', stderr: 'SSHWriter[bytes]') -> MaybeAwait[None]: """Return a handler for an SCP server session""" - async def _run_handler() -> None: - """Run an SCP server to handle this request""" - - try: - await handler.run(args.path) - finally: - sftp_server.exit() - try: args = _SCPArgParser().parse(command) except ValueError as exc: + if inspect.iscoroutine(sftp_server): + sftp_server.close() + stdin.logger.info('Error starting SCP server: %s', str(exc)) stderr.write(b'scp: ' + str(exc).encode('utf-8') + b'\n') cast('SSHServerChannel', stderr.channel).exit(1) @@ -1067,15 +1139,4 @@ async def _run_handler() -> None: stdin.logger.info('Starting SCP server, args: %s', command[4:].strip()) - fs = SFTPServerFS(sftp_server) - - handler: Union[_SCPSource, _SCPSink] - - if args.source: - handler = _SCPSource(fs, stdin, stdout, args.preserve, args.recurse, - error_handler=False, server=True) - else: - handler = _SCPSink(fs, stdin, stdout, args.must_be_dir, args.preserve, - args.recurse, error_handler=False, server=True) - - return _run_handler() + return _scp_handler(sftp_server, args, stdin, stdout) diff --git a/asyncssh/server.py b/asyncssh/server.py index db533b6..d80934b 100644 --- a/asyncssh/server.py +++ b/asyncssh/server.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,21 +31,36 @@ if TYPE_CHECKING: # pylint: disable=cyclic-import - from .connection import SSHServerConnection + from .connection import SSHClientConnection, SSHServerConnection + from .connection import SSHAcceptHandler from .channel import SSHServerChannel, SSHTCPChannel, SSHUNIXChannel + from .channel import SSHTunTapChannel from .session import SSHServerSession, SSHTCPSession, SSHUNIXSession - - -_NewSession = Union[bool, 'SSHServerSession', SSHServerSessionFactory, - Tuple['SSHServerChannel', 'SSHServerSession'], - Tuple['SSHServerChannel', SSHServerSessionFactory]] -_NewTCPSession = Union[bool, 'SSHTCPSession', SSHSocketSessionFactory, - Tuple['SSHTCPChannel', 'SSHTCPSession'], - Tuple['SSHTCPChannel', SSHSocketSessionFactory]] -_NewUNIXSession = Union[bool, 'SSHUNIXSession', SSHSocketSessionFactory, - Tuple['SSHUNIXChannel', 'SSHUNIXSession'], - Tuple['SSHUNIXChannel', SSHSocketSessionFactory]] -_NewListener = Union[bool, SSHListener] + from .session import SSHTunTapSession + + +_NewSession = Union[ + bool, 'SSHClientConnection', + MaybeAwait['SSHServerSession'], SSHServerSessionFactory, + Tuple['SSHServerChannel', MaybeAwait['SSHServerSession']], + Tuple['SSHServerChannel', SSHServerSessionFactory]] +_NewTCPSession = Union[ + bool, 'SSHClientConnection', + MaybeAwait['SSHTCPSession'], SSHSocketSessionFactory, + Tuple['SSHTCPChannel', MaybeAwait['SSHTCPSession']], + Tuple['SSHTCPChannel', SSHSocketSessionFactory]] +_NewUNIXSession = Union[ + bool, 'SSHClientConnection', + MaybeAwait['SSHUNIXSession'], SSHSocketSessionFactory, + Tuple['SSHUNIXChannel', MaybeAwait['SSHUNIXSession']], + Tuple['SSHUNIXChannel', SSHSocketSessionFactory]] +_NewTunTapSession = Union[ + bool, 'SSHClientConnection', + MaybeAwait['SSHTunTapSession'], SSHSocketSessionFactory, + Tuple['SSHTunTapChannel', MaybeAwait['SSHTunTapSession']], + Tuple['SSHTunTapChannel', SSHSocketSessionFactory]] +_NewTCPListener = Union[bool, 'SSHAcceptHandler', MaybeAwait[SSHListener]] +_NewUNIXListener = Union[bool, MaybeAwait[SSHListener]] class SSHServer: @@ -54,6 +69,14 @@ class SSHServer: Applications may subclass this when implementing an SSH server to provide custom authentication and request handlers. + Whenever a new SSH server connection is accepted, a corresponding + SSHServer object is created and the method :meth:`connection_made` + is called, passing in the :class:`SSHServerConnection` object. + + When the connection is closed, the method :meth:`connection_lost` + is called with an exception representing the reason for the + disconnect, or `None` if the connection was closed cleanly. + The method :meth:`begin_auth` can be overridden decide whether or not authentication is required, and additional callbacks are provided for each form of authentication in cases where authentication @@ -144,16 +167,19 @@ def begin_auth(self, username: str) -> MaybeAwait[bool]: return True # pragma: no cover - def auth_completed(self) -> None: + def auth_completed(self) -> MaybeAwait[None]: """Authentication was completed successfully This method is called when authentication has completed - succesfully. Applications may use this method to perform + successfully. Applications may use this method to perform processing based on the authenticated username or options in the authorized keys list or certificate associated with the user before any sessions are opened or forwarding requests are handled. + If blocking operations need to be performed when authentication + completes, this method may be defined as a coroutine. + """ def validate_gss_principal(self, username: str, user_principal: str, @@ -575,7 +601,7 @@ def kbdint_auth_supported(self) -> bool: authentication is supported. Applications wishing to support it must have this method return `True` and implement :meth:`get_kbdint_challenge` and :meth:`validate_kbdint_response` - to generate the apporiate challenges and validate the responses + to generate the appropriate challenges and validate the responses for the user being authenticated. By default, this method returns `NotImplemented` tying @@ -677,7 +703,7 @@ def session_requested(self) -> MaybeAwait[_NewSession]: If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHServerSession` object can be returned instead of - the session iself. This can be either returned directly or as + the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHServerChannel` object. To reject this request, this method should return `False` @@ -729,6 +755,11 @@ def connection_requested(self, dest_host: str, dest_port: int, :exc:`ChannelOpenError` exception with the reason for the failure. + If the application wishes to tunnel the connection over + another SSH connection, this method should return an + :class:`SSHClientConnection` connected to the desired + tunnel host. + If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHTCPSession` object which can be used to process the @@ -742,7 +773,7 @@ def connection_requested(self, dest_host: str, dest_port: int, If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHTCPSession` object can be returned instead of - the session iself. This can be either returned directly or as + the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHTCPChannel` object. By default, all connection requests are rejected. @@ -782,7 +813,7 @@ def connection_requested(self, dest_host: str, dest_port: int, return False # pragma: no cover def server_requested(self, listen_host: str, - listen_port: int) -> MaybeAwait[_NewListener]: + listen_port: int) -> MaybeAwait[_NewTCPListener]: """Handle a request to listen on a TCP/IP address and port This method is called when a client makes a request to @@ -844,6 +875,11 @@ def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession: :exc:`ChannelOpenError` exception with the reason for the failure. + If the application wishes to tunnel the connection over + another SSH connection, this method should return an + :class:`SSHClientConnection` connected to the desired + tunnel host. + If the application wishes to process the data on the connection itself, this method should return either an :class:`SSHUNIXSession` object which can be used to process the @@ -857,7 +893,7 @@ def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession: If blocking operations need to be performed before the session can be created, a coroutine which returns an :class:`SSHUNIXSession` object can be returned instead of - the session iself. This can be either returned directly or as + the session itself. This can be either returned directly or as a part of a tuple with an :class:`SSHUNIXChannel` object. By default, all connection requests are rejected. @@ -888,7 +924,7 @@ def unix_connection_requested(self, dest_path: str) -> _NewUNIXSession: return False # pragma: no cover def unix_server_requested(self, listen_path: str) -> \ - MaybeAwait[_NewListener]: + MaybeAwait[_NewUNIXListener]: """Handle a request to listen on a UNIX domain socket This method is called when a client makes a request to @@ -930,3 +966,129 @@ def unix_server_requested(self, listen_path: str) -> \ """ return False # pragma: no cover + + def tun_requested(self, unit: Optional[int]) -> _NewTunTapSession: + """Handle a layer 3 tunnel request + + This method is called when a layer 3 tunnel request is received + by the server. Applications wishing to accept such tunnels must + override this method. + + To allow standard forwarding of data on the connection to the + requested TUN device, this method should return `True`. + + To reject this request, this method should return `False` + to send back a "Connection refused" response or raise an + :exc:`ChannelOpenError` exception with the reason for + the failure. + + If the application wishes to tunnel the data over another + SSH connection, this method should return an + :class:`SSHClientConnection` connected to the desired + tunnel host. + + If the application wishes to process the data on the + connection itself, this method should return either an + :class:`SSHTunTapSession` object which can be used to process the + data received on the channel or a tuple consisting of of an + :class:`SSHTunTapChannel` object created with + :meth:`create_tuntap_channel() + ` and an + :class:`SSHTunTapSession`, if the application wishes + to pass non-default arguments when creating the channel. + + If blocking operations need to be performed before the session + can be created, a coroutine which returns an + :class:`SSHTunTapSession` object can be returned instead of + the session itself. This can be either returned directly or as + a part of a tuple with an :class:`SSHTunTapChannel` object. + + By default, all layer 3 tunnel requests are rejected. + + :param dest_path: + The path the client wishes to connect to + :type dest_path: `str` + + :returns: One of the following: + + * An :class:`SSHTunTapSession` object or a coroutine + which returns an :class:`SSHTunTapSession` + * A tuple consisting of an :class:`SSHTunTapChannel` + and the above + * A `callable` or coroutine handler function which + takes AsyncSSH stream objects for reading from + and writing to the connection + * A tuple consisting of an :class:`SSHTunTapChannel` + and the above + * `True` to request standard layer 3 tunnel forwarding + * `False` to refuse the connection + + :raises: :exc:`ChannelOpenError` if the connection shouldn't + be accepted + + """ + + return False # pragma: no cover + + def tap_requested(self, unit: Optional[int]) -> _NewTunTapSession: + """Handle a layer 2 tunnel request + + This method is called when a layer 2 tunnel request is received + by the server. Applications wishing to accept such tunnels must + override this method. + + To allow standard forwarding of data on the connection to the + requested TAP device, this method should return `True`. + + To reject this request, this method should return `False` + to send back a "Connection refused" response or raise an + :exc:`ChannelOpenError` exception with the reason for + the failure. + + If the application wishes to tunnel the data over another + SSH connection, this method should return an + :class:`SSHClientConnection` connected to the desired + tunnel host. + + If the application wishes to process the data on the + connection itself, this method should return either an + :class:`SSHTunTapSession` object which can be used to process the + data received on the channel or a tuple consisting of of an + :class:`SSHTunTapChannel` object created with + :meth:`create_tuntap_channel() + ` and an + :class:`SSHTunTapSession`, if the application wishes + to pass non-default arguments when creating the channel. + + If blocking operations need to be performed before the session + can be created, a coroutine which returns an + :class:`SSHTunTapSession` object can be returned instead of + the session itself. This can be either returned directly or as + a part of a tuple with an :class:`SSHTunTapChannel` object. + + By default, all layer 2 tunnel requests are rejected. + + :param dest_path: + The path the client wishes to connect to + :type dest_path: `str` + + :returns: One of the following: + + * An :class:`SSHTunTapSession` object or a coroutine + which returns an :class:`SSHTunTapSession` + * A tuple consisting of an :class:`SSHTunTapChannel` + and the above + * A `callable` or coroutine handler function which + takes AsyncSSH stream objects for reading from + and writing to the connection + * A tuple consisting of an :class:`SSHTunTapChannel` + and the above + * `True` to request standard layer 2 tunnel forwarding + * `False` to refuse the connection + + :raises: :exc:`ChannelOpenError` if the connection shouldn't + be accepted + + """ + + return False # pragma: no cover diff --git a/asyncssh/session.py b/asyncssh/session.py index 57d2b02..329975c 100644 --- a/asyncssh/session.py +++ b/asyncssh/session.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -21,21 +21,16 @@ """SSH session handlers""" from typing import TYPE_CHECKING, Any, AnyStr, Callable, Generic -from typing import Mapping, Optional, Tuple, Union +from typing import Mapping, Optional, Tuple if TYPE_CHECKING: # pylint: disable=cyclic-import from .channel import SSHClientChannel, SSHServerChannel - from .channel import SSHTCPChannel, SSHUNIXChannel + from .channel import SSHTCPChannel, SSHUNIXChannel, SSHTunTapChannel DataType = Optional[int] -TermModes = Mapping[int, int] -TermModesArg = Optional[TermModes] -TermSize = Tuple[int, int, int, int] -TermSizeArg = Union[None, Tuple[int, int], TermSize] - class SSHSession(Generic[AnyStr]): """SSH session handler""" @@ -262,7 +257,7 @@ def connection_made(self, chan: 'SSHServerChannel[AnyStr]') -> None: def pty_requested(self, term_type: str, term_size: Tuple[int, int, int, int], term_modes: Mapping[int, int]) -> bool: - """A psuedo-terminal has been requested + """A pseudo-terminal has been requested This method is called when the client sends a request to allocate a pseudo-terminal with the requested terminal type, size, and @@ -537,7 +532,45 @@ def connection_made(self, chan: 'SSHUNIXChannel[AnyStr]') -> None: """ +class SSHTunTapSession(SSHSession[bytes]): + """SSH TUN/TAP session handler + + Applications should subclass this when implementing a handler for + SSH TUN/TAP tunnels. + + SSH client applications wishing to open a tunnel should call + :meth:`create_tun() ` or + :meth:`create_tap() ` on their + :class:`SSHClientConnection`, passing in a factory which returns + instances of this class. + + Server applications wishing to allow tunnel connections should + implement the coroutine :meth:`tun_requested() + ` or :meth:`tap_requested() + ` on their :class:`SSHServer` object + and have it return instances of this class. + + When a connection is successfully opened, :meth:`session_started` + will be called, after which the application can begin sending data. + Received data will be passed to the :meth:`data_received` method. + + """ + + def connection_made(self, chan: 'SSHTunTapChannel') -> None: + """Called when a channel is opened successfully + + This method is called when a channel is opened successfully. The + channel parameter should be stored if needed for later use. + + :param chan: + The channel which was successfully opened. + :type chan: :class:`SSHTunTapChannel` + + """ + + SSHSessionFactory = Callable[[], SSHSession[AnyStr]] SSHClientSessionFactory = Callable[[], SSHClientSession[AnyStr]] SSHTCPSessionFactory = Callable[[], SSHTCPSession[AnyStr]] SSHUNIXSessionFactory = Callable[[], SSHUNIXSession[AnyStr]] +SSHTunTapSessionFactory = Callable[[], SSHTunTapSession] diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 8563de8..6d41ac2 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2022 by Ron Frederick and others. +# Copyright (c) 2015-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -37,7 +37,7 @@ from typing import TYPE_CHECKING, AnyStr, AsyncIterator, Awaitable, Callable from typing import Dict, Generic, IO, Iterable, List, Mapping, Optional from typing import Sequence, Set, Tuple, Type, TypeVar, Union, cast, overload -from typing_extensions import Literal, Protocol +from typing_extensions import Literal, Protocol, Self from . import constants from .constants import DEFAULT_LANG @@ -106,13 +106,27 @@ from .misc import BytesOrStr, Error, FilePath, MaybeAwait, OptExcInfo, Record from .misc import ConnectionLost -from .misc import async_context_manager, get_symbol_names, hide_empty, plural +from .misc import async_context_manager, get_symbol_names, hide_empty +from .misc import make_sparse_file, plural from .packet import Boolean, Byte, String, UInt16, UInt32, UInt64 from .packet import PacketDecodeError, SSHPacket, SSHPacketLogger from .version import __author__, __version__ +_pywin32_available = False + +if sys.platform == 'win32': # pragma: no cover + try: + import msvcrt + import pywintypes + import win32file + import winerror + import winioctlcon + _pywin32_available = True + except ImportError: + pass + if TYPE_CHECKING: # pylint: disable=cyclic-import @@ -133,33 +147,44 @@ _SFTPFileObj = IO[bytes] _SFTPPath = Union[bytes, FilePath] +_SFTPPaths = Union[_SFTPPath, Sequence[_SFTPPath]] +_SFTPPatList = List[Union[bytes, List[bytes]]] _SFTPStatFunc = Callable[[_SFTPPath], Awaitable['SFTPAttrs']] +_SFTPClientFileOrPath = Union['SFTPClientFile', _SFTPPath] + _SFTPNames = Tuple[Sequence['SFTPName'], bool] _SFTPOSAttrs = Union[os.stat_result, 'SFTPAttrs'] _SFTPOSVFSAttrs = Union[os.statvfs_result, 'SFTPVFSAttrs'] -_SFTPOnErrorHandler = Callable[[Callable, bytes, OptExcInfo], None] -_SFTPPacketHandler = Callable[['SFTPServerHandler', SSHPacket], - Awaitable[object]] +_SFTPOnErrorHandler = Optional[Callable[[Callable, bytes, OptExcInfo], None]] +_SFTPPacketHandler = Optional[Callable[['SFTPServerHandler', SSHPacket], + Awaitable[object]]] SFTPErrorHandler = Union[None, Literal[False], Callable[[Exception], None]] SFTPProgressHandler = Optional[Callable[[bytes, bytes, int, int], None]] +_T = TypeVar('_T') + MIN_SFTP_VERSION = 3 MAX_SFTP_VERSION = 6 -SFTP_BLOCK_SIZE = 16384 -_MAX_SFTP_READ_SIZE = 4*1024*1024 # 4 MiB +SAFE_SFTP_READ_LEN = 16*1024 # 16 KiB +SAFE_SFTP_WRITE_LEN = 16*1024 # 16 KiB + +MAX_SFTP_READ_LEN = 4*1024*1024 # 4 MiB +MAX_SFTP_WRITE_LEN = 4*1024*1024 # 4 MiB +MAX_SFTP_PACKET_LEN = MAX_SFTP_WRITE_LEN + 1024 + +_COPY_DATA_BLOCK_SIZE = 256*1024 # 256 KiB _MAX_SFTP_REQUESTS = 128 _MAX_READDIR_NAMES = 128 +_MAX_SPARSE_RANGES = 128 _NSECS_IN_SEC = 1_000_000_000 -_T = TypeVar('_T') - _const_dict: Mapping[str, int] = constants.__dict__ @@ -199,14 +224,14 @@ class _SFTPGlobProtocol(Protocol): async def stat(self, path: bytes) -> 'SFTPAttrs': """Get attributes of a file""" - async def listdir(self, path: bytes) -> Sequence[bytes]: - """List the contents of a directory""" + def scandir(self, path: bytes) -> AsyncIterator['SFTPName']: + """Return names and attributes of the files in a directory""" class SFTPFileProtocol(Protocol): """Protocol for accessing a file via an SFTP server""" - async def __aenter__(self) -> 'SFTPFileProtocol': + async def __aenter__(self) -> Self: """Allow SFTPFileProtocol to be used as an async context manager""" async def __aexit__(self, _exc_type: Optional[Type[BaseException]], @@ -214,6 +239,10 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]], _traceback: Optional[TracebackType]) -> bool: """Wait for file close when used as an async context manager""" + def request_ranges(self, offset: int, length: int) -> \ + AsyncIterator[Tuple[int, int]]: + """Return file ranges containing data""" + async def read(self, size: int, offset: int) -> bytes: """Read data from the local file""" @@ -227,6 +256,10 @@ async def close(self) -> None: class _SFTPFSProtocol(Protocol): """Protocol for accessing a filesystem via an SFTP server""" + @property + def limits(self) -> 'SFTPLimits': + """SFTP server limits associated with this SFTP session""" + @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" @@ -238,23 +271,19 @@ def compose_path(self, path: bytes, parent: Optional[bytes] = None) -> bytes: """Compose a path""" - async def stat(self, path: bytes) -> 'SFTPAttrs': - """Get attributes of a file or directory, following symlinks""" + async def stat(self, path: bytes, *, + follow_symlinks: bool = True) -> 'SFTPAttrs': + """Get attributes of a file, directory, or symlink""" - async def lstat(self, path: bytes) -> 'SFTPAttrs': - """Get attributes of a file or directory""" - - async def setstat(self, path: bytes, attrs: 'SFTPAttrs') -> None: - """Set attributes of a file or directory""" - - async def exists(self, path: bytes) -> bool: - """Return if a path exists""" + async def setstat(self, path: bytes, attrs: 'SFTPAttrs', *, + follow_symlinks: bool = True) -> None: + """Set attributes of a file, directory, or symlink""" async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" - async def listdir(self, path: bytes) -> Sequence[bytes]: - """List the contents of a directory""" + def scandir(self, path: bytes) -> AsyncIterator['SFTPName']: + """Return names and attributes of the files in a directory""" async def mkdir(self, path: bytes) -> None: """Create a directory""" @@ -266,7 +295,8 @@ async def symlink(self, oldpath: bytes, newpath: bytes) -> None: """Create a symbolic link""" @async_context_manager - async def open(self, path: bytes, mode: str) -> SFTPFileProtocol: + async def open(self, path: bytes, mode: str, + block_size: int = -1) -> SFTPFileProtocol: """Open a file""" @@ -408,7 +438,7 @@ def _utime_to_attrs(times: Optional[Tuple[float, float]] = None, else: if hasattr(time, 'time_ns'): atime, atime_ns = _nsec_to_tuple(time.time_ns()) - else: + else: # pragma: no cover atime, atime_ns = _float_sec_to_tuple(time.time()) mtime, mtime_ns = atime, atime_ns @@ -429,7 +459,7 @@ def _lookup_uid(user: Optional[str]) -> Optional[int]: try: uid = int(user) except ValueError: - raise SFTPOwnerInvalid('Invalid owner: %s' % user) from None + raise SFTPOwnerInvalid(f'Invalid owner: {user}') from None else: uid = None @@ -448,7 +478,7 @@ def _lookup_gid(group: Optional[str]) -> Optional[int]: try: gid = int(group) except ValueError: - raise SFTPGroupInvalid('Invalid group: %s' % group) from None + raise SFTPGroupInvalid(f'Invalid group: {group}') from None else: gid = None @@ -499,7 +529,7 @@ def _mode_to_pflags(mode: str) -> Tuple[int, bool]: pflags = _open_modes.get(mode) if not pflags: - raise ValueError('Invalid mode: %r' % mode) + raise ValueError(f'Invalid mode: {mode!r}') return pflags, binary @@ -564,7 +594,8 @@ def _to_local_path(path: bytes) -> _LocalPath: return path -def _setstat(path: Union[int, _SFTPPath], attrs: 'SFTPAttrs') -> None: +def _setstat(path: Union[int, _SFTPPath], attrs: 'SFTPAttrs', *, + follow_symlinks: bool = True) -> None: """Utility function to set file attributes""" if attrs.size is not None: @@ -581,7 +612,7 @@ def _setstat(path: Union[int, _SFTPPath], attrs: 'SFTPAttrs') -> None: if ((atime_ns is None and mtime_ns is not None) or (atime_ns is not None and mtime_ns is None)): - stat_result = os.stat(path) + stat_result = os.stat(path, follow_symlinks=follow_symlinks) if atime_ns is None and mtime_ns is not None: atime_ns = stat_result.st_atime_ns @@ -591,125 +622,86 @@ def _setstat(path: Union[int, _SFTPPath], attrs: 'SFTPAttrs') -> None: if uid is not None and gid is not None: try: - os.chown(path, uid, gid) + os.chown(path, uid, gid, follow_symlinks=follow_symlinks) + except NotImplementedError: # pragma: no cover + pass except AttributeError: # pragma: no cover raise NotImplementedError from None if attrs.permissions is not None: - os.chmod(path, stat.S_IMODE(attrs.permissions)) + try: + os.chmod(path, stat.S_IMODE(attrs.permissions), + follow_symlinks=follow_symlinks) + except NotImplementedError: # pragma: no cover + pass if atime_ns is not None and mtime_ns is not None: - os.utime(path, ns=(atime_ns, mtime_ns)) - - -def _split_path_by_globs(pattern: bytes) -> \ - Tuple[Optional[bytes], Sequence[object]]: - """Split path grouping parts without glob pattern""" - - basedir: Optional[bytes] = None - patlist: List[object] = [] - plain: List[bytes] = [] - - for current in pattern.split(b'/'): - if any(c in current for c in b'*?[]'): - if plain: - if patlist: - patlist.append(plain) - else: - basedir = b'/'.join(plain) or b'/' - - plain = [] - - patlist.append(current) - else: - plain.append(current) - - if plain: - patlist.append(plain) - - return basedir, patlist - - -async def _glob(fs: _SFTPGlobProtocol, basedir: Optional[bytes], - patlist: Sequence[object], result: List[bytes]) -> None: - """Recursively match a glob pattern""" - - pattern, newpatlist = patlist[0], patlist[1:] + try: + os.utime(path, ns=(atime_ns, mtime_ns), + follow_symlinks=follow_symlinks) + except NotImplementedError: # pragma: no cover + pass - names = await fs.listdir(basedir or b'.') - if isinstance(pattern, list): - if len(pattern) == 1 and not pattern[0] and not newpatlist: - result.append(basedir or b'.') - return +if sys.platform == 'win32' and _pywin32_available: # pragma: no cover + async def _request_ranges(file_obj: _SFTPFileObj, offset: int, + length: int) -> AsyncIterator[Tuple[int, int]]: + """Return file ranges containing data on Windows""" - for name in names: - if name == pattern[0]: - newbase = posixpath.join(basedir or b'', *pattern) - await fs.stat(newbase) + handle = msvcrt.get_osfhandle(file_obj.fileno()) + bufsize = _MAX_SPARSE_RANGES * 16 - if not newpatlist: - result.append(newbase) + while True: + try: + query_range = offset.to_bytes(8, 'little') + \ + length.to_bytes(8, 'little') + + ranges = win32file.DeviceIoControl( + handle, winioctlcon.FSCTL_QUERY_ALLOCATED_RANGES, + query_range, bufsize, None) + except pywintypes.error as exc: + if exc.args[0] == winerror.ERROR_MORE_DATA: + bufsize *= 2 else: - await _glob(fs, newbase, newpatlist, result) + raise + else: break - else: - if pattern == b'**': - await _glob(fs, basedir, newpatlist, result) - - for name in names: - if name in (b'.', b'..'): - continue - if fnmatch(name, cast(bytes, pattern)): - newbase = posixpath.join(basedir or b'', name) + for pos in range(0, len(ranges), 16): + offset = int.from_bytes(ranges[pos:pos+8], 'little') + length = int.from_bytes(ranges[pos+8:pos+16], 'little') + yield offset, length +elif hasattr(os, 'SEEK_DATA'): + async def _request_ranges(file_obj: _SFTPFileObj, offset: int, + length: int) -> AsyncIterator[Tuple[int, int]]: + """Return file ranges containing data""" - if not newpatlist or (len(newpatlist) == 1 and - not newpatlist[0]): - result.append(newbase) - else: - attrs = await fs.stat(newbase) - - if attrs.type == FILEXFER_TYPE_DIRECTORY: - if pattern == b'**': - await _glob(fs, newbase, patlist, result) - else: - await _glob(fs, newbase, newpatlist, result) - - -async def match_glob(fs: _SFTPGlobProtocol, pattern: bytes, - error_handler: SFTPErrorHandler = None, - sftp_version = MIN_SFTP_VERSION) -> Sequence[bytes]: - """Match a glob pattern""" + end = offset + limit = offset + length - names: List[bytes] = [] - - try: - if any(c in pattern for c in b'*?[]'): - basedir, patlist = _split_path_by_globs(pattern) - await _glob(fs, basedir, patlist, names) - - if not names: - exc = SFTPNoSuchPath if sftp_version >= 4 else SFTPNoSuchFile - raise exc('No matches found') - else: - await fs.stat(pattern) - names.append(pattern) - except (OSError, SFTPError) as exc: - setattr(exc, 'srcpath', pattern) + try: + while end < limit: + start = file_obj.seek(end, os.SEEK_DATA) + end = min(file_obj.seek(start, os.SEEK_HOLE), limit) + yield start, end - start + except OSError as exc: # pragma: no cover + if exc.errno != errno.ENXIO: + raise +else: # pragma: no cover + async def _request_ranges(file_obj: _SFTPFileObj, offset: int, + length: int) -> AsyncIterator[Tuple[int, int]]: + """Sparse files aren't supported - return the full input range""" - if error_handler: - error_handler(exc) - else: - raise + # pylint: disable=unused-argument - return names + if length: + yield offset, length class _SFTPParallelIO(Generic[_T]): """Parallelize I/O requests on files - This class issues parallel read and wite requests on files. + This class issues parallel read and write requests on files. """ @@ -719,7 +711,14 @@ def __init__(self, block_size: int, max_requests: int, self._max_requests = max_requests self._offset = offset self._bytes_left = size - self._pending: Set['asyncio.Task[None]'] = set() + self._pending: Set['asyncio.Task[Tuple[int, int, int, _T]]'] = set() + + async def _start_task(self, offset: int, size: int) -> \ + Tuple[int, int, int, _T]: + """Start a task to perform file I/O on a particular byte range""" + + count, result = await self.run_task(offset, size) + return offset, size, count, result def _start_tasks(self) -> None: """Create parallel file I/O tasks""" @@ -727,57 +726,48 @@ def _start_tasks(self) -> None: while self._bytes_left and len(self._pending) < self._max_requests: size = min(self._bytes_left, self._block_size) - task = asyncio.ensure_future(self.run_task(self._offset, size)) + task = asyncio.ensure_future(self._start_task(self._offset, size)) self._pending.add(task) self._offset += size self._bytes_left -= size - async def start(self) -> None: - """Start parallel I/O""" - - async def run_task(self, offset: int, size: int) -> None: + async def run_task(self, offset: int, size: int) -> Tuple[int, _T]: """Perform file I/O on a particular byte range""" raise NotImplementedError - async def finish(self) -> _T: - """Finish parallel I/O""" + async def iter(self) -> AsyncIterator[Tuple[int, _T]]: + """Perform file I/O and return async iterator of results""" - async def cleanup(self) -> None: - """Clean up parallel I/O""" + self._start_tasks() - async def run(self) -> _T: - """Perform all file I/O and return result or exception""" + while self._pending: + done, self._pending = await asyncio.wait( + self._pending, return_when=asyncio.FIRST_COMPLETED) - try: - await self.start() - - self._start_tasks() - - while self._pending: - done, self._pending = await asyncio.wait( - self._pending, return_when=asyncio.FIRST_COMPLETED) - - exceptions = [] + exceptions = [] - for task in done: - exc = task.exception() - - if exc and not isinstance(exc, SFTPEOFError): - exceptions.append(exc) + for task in done: + try: + offset, size, count, result = task.result() + yield offset, result - if exceptions: - for task in self._pending: - task.cancel() + if count and count < size: + self._pending.add(asyncio.ensure_future( + self._start_task(offset+count, size-count))) + except SFTPEOFError: + self._bytes_left = 0 + except (OSError, SFTPError) as exc: + exceptions.append(exc) - raise exceptions[0] + if exceptions: + for task in self._pending: + task.cancel() - self._start_tasks() + raise exceptions[0] - return await self.finish() - finally: - await self.cleanup() + self._start_tasks() class _SFTPFileReader(_SFTPParallelIO[bytes]): @@ -791,33 +781,32 @@ def __init__(self, block_size: int, max_requests: int, self._handler = handler self._handle = handle self._start = offset - self._data = bytearray() - async def run_task(self, offset: int, size: int) -> None: + async def run_task(self, offset: int, size: int) -> Tuple[int, bytes]: """Read a block of the file""" - while size: - data, _ = await self._handler.read(self._handle, offset, size) + data, _ = await self._handler.read(self._handle, offset, size) - pos = offset - self._start - pad = pos - len(self._data) + return len(data), data - if pad > 0: - self._data += pad * b'\0' + async def run(self) -> bytes: + """Reassemble and return data from parallel reads""" + + result = bytearray() - datalen = len(data) - self._data[pos:pos+datalen] = data + async for offset, data in self.iter(): + pos = offset - self._start + pad = pos - len(result) - offset += datalen - size -= datalen + if pad > 0: + result += pad * b'\0' - async def finish(self) -> bytes: - """Finish parallel read""" + result[pos:pos+len(data)] = data - return bytes(self._data) + return bytes(result) -class _SFTPFileWriter(_SFTPParallelIO[None]): +class _SFTPFileWriter(_SFTPParallelIO[int]): """Parallelized SFTP file writer""" def __init__(self, block_size: int, max_requests: int, @@ -830,15 +819,21 @@ def __init__(self, block_size: int, max_requests: int, self._start = offset self._data = data - async def run_task(self, offset: int, size: int) -> None: + async def run_task(self, offset: int, size: int) -> Tuple[int, int]: """Write a block to the file""" pos = offset - self._start await self._handler.write(self._handle, offset, self._data[pos:pos+size]) + return size, size + + async def run(self): + """Perform parallel writes""" + async for _ in self.iter(): + pass -class _SFTPFileCopier(_SFTPParallelIO): +class _SFTPFileCopier(_SFTPParallelIO[int]): """SFTP file copier This class parforms an SFTP file copy, initiating multiple @@ -846,11 +841,13 @@ class _SFTPFileCopier(_SFTPParallelIO): """ - def __init__(self, block_size: int, max_requests: int, offset: int, - total_bytes: int, srcfs: _SFTPFSProtocol, - dstfs: _SFTPFSProtocol, srcpath: bytes, dstpath: bytes, + def __init__(self, block_size: int, max_requests: int, total_bytes: int, + sparse: bool, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, + srcpath: bytes, dstpath: bytes, progress_handler: SFTPProgressHandler): - super().__init__(block_size, max_requests, offset, total_bytes) + super().__init__(block_size, max_requests, 0, 0) + + self._sparse = sparse self._srcfs = srcfs self._dstfs = dstfs @@ -865,51 +862,78 @@ def __init__(self, block_size: int, max_requests: int, offset: int, self._total_bytes = total_bytes self._progress_handler = progress_handler - async def start(self) -> None: - """Start parallel copy""" - - self._src = await self._srcfs.open(self._srcpath, 'rb') - self._dst = await self._dstfs.open(self._dstpath, 'wb') - - if self._progress_handler and self._total_bytes == 0: - self._progress_handler(self._srcpath, self._dstpath, 0, 0) - - async def run_task(self, offset: int, size: int) -> None: - """Copy the next block of the file""" + async def run_task(self, offset: int, size: int) -> Tuple[int, int]: + """Copy a block of the source file""" assert self._src is not None assert self._dst is not None - while size: - data = await self._src.read(size, offset) + data = await self._src.read(size, offset) + await self._dst.write(data, offset) + datalen = len(data) - if not data: - exc = SFTPFailure('Unexpected EOF during file copy') + return datalen, datalen - setattr(exc, 'filename', self._srcpath) - setattr(exc, 'offset', offset) + async def run(self) -> None: + """Perform parallel file copy""" - raise exc + async def _request_nonsparse_range(offset: int, length: int) -> \ + AsyncIterator[Tuple[int, int]]: + """Return the entire file as the range to copy""" - await self._dst.write(data, offset) + yield offset, length - datalen = len(data) + try: + self._src = await self._srcfs.open(self._srcpath, 'rb', + block_size=0) + self._dst = await self._dstfs.open(self._dstpath, 'wb', + block_size=0) - if self._progress_handler: - self._bytes_copied += datalen - self._progress_handler(self._srcpath, self._dstpath, - self._bytes_copied, self._total_bytes) + if self._progress_handler and self._total_bytes == 0: + self._progress_handler(self._srcpath, self._dstpath, 0, 0) + return - offset += datalen - size -= datalen + if self._sparse: + ranges = self._src.request_ranges(0, self._total_bytes) + else: + ranges = _request_nonsparse_range(0, self._total_bytes) + + if self._srcfs == self._dstfs and \ + isinstance(self._srcfs, SFTPClient) and \ + self._srcfs.supports_remote_copy: + async for offset, length in ranges: + await self._srcfs.remote_copy( + cast(SFTPClientFile, self._src), + cast(SFTPClientFile, self._dst), + offset, length, offset) + + self._bytes_copied += length + + if self._progress_handler: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) + else: + async for self._offset, self._bytes_left in ranges: + async for _, datalen in self.iter(): + self._bytes_copied += datalen - async def cleanup(self) -> None: - """Clean up parallel copy""" + if self._progress_handler and datalen != 0: + self._progress_handler(self._srcpath, self._dstpath, + self._bytes_copied, + self._total_bytes) - try: + if self._bytes_copied != self._total_bytes and not self._sparse: + exc = SFTPFailure('Unexpected EOF during file copy') + + setattr(exc, 'filename', self._srcpath) + setattr(exc, 'offset', self._bytes_copied) + + raise exc + finally: if self._src: # pragma: no branch await self._src.close() - finally: + if self._dst: # pragma: no branch await self._dst.close() @@ -960,7 +984,7 @@ def construct(packet: SSHPacket) -> Optional['SFTPError']: try: exc = _sftp_error_map[code](reason, lang) except KeyError: - exc = SFTPError(code, '%s (error %d)' % (reason, code), lang) + exc = SFTPError(code, f'{reason} (error {code})', lang) exc.decode(packet) return exc @@ -1423,11 +1447,11 @@ def __init__(self, reason: str, lang: str = DEFAULT_LANG): class SFTPInvalidParameter(SFTPError): """SFTP invalid parameter (SFTPv6+) - This exception is raised when paramters in a request are + This exception is raised when parameters in a request are out of range or incompatible with one another. :param reason: - Details about the invalid paramter + Details about the invalid parameter :param lang: (optional) The language the reason is in :type reason: `str` @@ -1713,7 +1737,7 @@ def _format(self, k: str, v: object) -> Optional[str]: return _file_types.get(cast(int, v), str(v)) \ if v != FILEXFER_TYPE_UNKNOWN else None elif k == 'permissions': - return '{:04o}'.format(cast(int, v)) + return f'{cast(int, v):04o}' elif k in ('atime', 'crtime', 'mtime', 'ctime'): return self._format_ns(k) elif k in ('atime_ns', 'crtime_ns', 'mtime_ns', 'ctime_ns'): @@ -1845,11 +1869,19 @@ def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPAttrs': flags = packet.get_uint32() attrs = cls() + # Work around a bug seen in a Huawei SFTP server where + # FILEXFER_ATTR_MODIFYTIME is included in flags, even though + # the SFTP version is set to 3. That flag is only defined for + # SFTPv4 and later. + if sftp_version == 3 and flags & (FILEXFER_ATTR_ACMODTIME | + FILEXFER_ATTR_MODIFYTIME): + flags &= ~FILEXFER_ATTR_MODIFYTIME + unsupported_attrs = flags & ~_valid_attr_flags[sftp_version] if unsupported_attrs: - raise SFTPBadMessage('Unsupported attribute flags: 0x%08x' % - unsupported_attrs) + raise SFTPBadMessage( + f'Unsupported attribute flags: 0x{unsupported_attrs:08x}') if sftp_version >= 4: attrs.type = packet.get_byte() @@ -1871,7 +1903,7 @@ def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPAttrs': try: attrs.owner = owner.decode('utf-8') except UnicodeDecodeError: - raise SFTPOwnerInvalid('Invalid owner name: %s' % + raise SFTPOwnerInvalid('Invalid owner name: ' + owner.decode('utf-8', 'backslashreplace')) from None group = packet.get_string() @@ -1879,7 +1911,7 @@ def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPAttrs': try: attrs.group = group.decode('utf-8') except UnicodeDecodeError: - raise SFTPGroupInvalid('Invalid group name: %s' % + raise SFTPGroupInvalid('Invalid group name: ' + group.decode('utf-8', 'backslashreplace')) from None if flags & FILEXFER_ATTR_PERMISSIONS: @@ -2119,6 +2151,277 @@ def decode(cls, packet: SSHPacket, sftp_version: int) -> 'SFTPName': return cls(filename, longname, attrs) +class SFTPLimits(Record): + """SFTP server limits + + SFTPLimits is a simple record class with the following fields: + + ================= ========================================= ====== + Field Description Type + ================= ========================================= ====== + max_packet_len Max allowed size of an SFTP packet uint64 + max_read_len Max allowed size of an SFTP read request uint64 + max_write_len Max allowed size of an SFTP write request uint64 + max_open_handles Max allowed number of open file handles uint64 + ================= ========================================= ====== + + """ + + max_packet_len: int + max_read_len: int + max_write_len: int + max_open_handles: int + + def encode(self) -> bytes: + """Encode SFTP server limits in an SSH packet""" + + return (UInt64(self.max_packet_len) + UInt64(self.max_read_len) + + UInt64(self.max_write_len) + UInt64(self.max_open_handles)) + + @classmethod + def decode(cls, packet: SSHPacket) -> Self: + """Decode bytes in an SSH packet as SFTP server limits""" + + max_packet_len = packet.get_uint64() + max_read_len = packet.get_uint64() + max_write_len = packet.get_uint64() + max_open_handles = packet.get_uint64() + + return cls(max_packet_len, max_read_len, + max_write_len, max_open_handles) + + def log(self, logger: SSHLogger, label: str) -> None: + """Log sending or receiving SFTP limits""" + + logger.debug1('%s erver limits:', label) + logger.debug1(' Max packet len: %d', self.max_packet_len) + logger.debug1(' Max read len: %d', self.max_read_len) + logger.debug1(' Max write len: %d', self.max_write_len) + logger.debug1(' Max open handles: %d', self.max_open_handles) + + +class SFTPRanges(Record): + """SFTP sparse file ranges""" + + ranges: List[Tuple[int, int]] + at_end: bool + + def encode(self) -> bytes: + """Encode sparse file ranges in an SSH packet""" + + return (UInt32(len(self.ranges)) + + b''.join((UInt64(offset) + UInt64(length) + for offset, length in self.ranges)) + + Boolean(self.at_end)) + + @classmethod + def decode(cls, packet: SSHPacket) -> Self: + """Decode bytes in an SSH packet as sparse file ranges""" + + count = packet.get_uint32() + ranges = [(packet.get_uint64(), packet.get_uint64()) + for _ in range(count)] + at_end = packet.get_boolean() + + return cls(ranges, at_end) + + def log(self, logger: SSHLogger, label: str) -> None: + """Log sending or receiving sparse file ranges""" + + logger.debug1('%s %s%s', label, + plural(len(self.ranges), 'sparse file range'), + ' (at end)' if self.at_end else '') + + for offset, length in self.ranges: + logger.debug1(' offset %d, length %d', offset, length) + + +class SFTPGlob: + """SFTP glob matcher""" + + def __init__(self, fs: _SFTPGlobProtocol, multiple=False): + self._fs = fs + self._multiple = multiple + self._prev_matches: Set[bytes] = set() + self._new_matches: List[SFTPName] = [] + self._matched = False + self._stat_cache: Dict[bytes, Optional[SFTPAttrs]] = {} + self._scandir_cache: Dict[bytes, List[SFTPName]] = {} + + def _split(self, pattern: bytes) -> Tuple[bytes, _SFTPPatList]: + """Split out exact parts of a glob pattern""" + + patlist: _SFTPPatList = [] + + if any(c in pattern for c in b'*?[]'): + path = b'' + plain: List[bytes] = [] + + for current in pattern.split(b'/'): + if any(c in current for c in b'*?[]'): + if plain: + if patlist: + patlist.append(plain) + else: + path = b'/'.join(plain) or b'/' + + plain = [] + + patlist.append(current) + else: + plain.append(current) + + if plain: + patlist.append(plain) + else: + path = pattern + + return path, patlist + + def _report_match(self, path, attrs): + """Report a matching name""" + + self._matched = True + + if self._multiple: + if path not in self._prev_matches: + self._prev_matches.add(path) + else: + return + + self._new_matches.append(SFTPName(path, attrs=attrs)) + + async def _stat(self, path) -> Optional[SFTPAttrs]: + """Cache results of calls to stat""" + + try: + return self._stat_cache[path] + except KeyError: + pass + + try: + attrs = await self._fs.stat(path) + except (SFTPNoSuchFile, SFTPPermissionDenied, SFTPNoSuchPath): + attrs = None + + self._stat_cache[path] = attrs + return attrs + + async def _scandir(self, path) -> AsyncIterator[SFTPName]: + """Cache results of calls to scandir""" + + try: + for entry in self._scandir_cache[path]: + yield entry + + return + except KeyError: + pass + + entries: List[SFTPName] = [] + + try: + async for entry in self._fs.scandir(path): + entries.append(entry) + yield entry + except (SFTPNoSuchFile, SFTPPermissionDenied, SFTPNoSuchPath): + pass + + self._scandir_cache[path] = entries + + async def _match_exact(self, path: bytes, pattern: Sequence[bytes], + patlist: _SFTPPatList) -> None: + """Match on an exact portion of a path""" + + newpath = posixpath.join(path, *pattern) + newpatlist = patlist[1:] + + attrs = await self._stat(newpath) + + if attrs is None: + return + + if newpatlist: + if attrs.type == FILEXFER_TYPE_DIRECTORY: + await self._match(newpath, attrs, newpatlist) + else: + self._report_match(newpath, attrs) + + async def _match_pattern(self, path: bytes, attrs: SFTPAttrs, + pattern: bytes, patlist: _SFTPPatList) -> None: + """Match on a pattern portion of a path""" + + newpatlist = patlist[1:] + + if pattern == b'**': + if newpatlist: + await self._match(path, attrs, newpatlist) + else: + self._report_match(path, attrs) + + async for entry in self._scandir(path or b'.'): + filename = cast(bytes, entry.filename) + + if filename in (b'.', b'..'): + continue + + if not pattern or fnmatch(filename, pattern): + newpath = posixpath.join(path, filename) + attrs = entry.attrs + + if pattern == b'**' and attrs.type == FILEXFER_TYPE_DIRECTORY: + await self._match(newpath, attrs, patlist) + elif newpatlist: + if attrs.type == FILEXFER_TYPE_DIRECTORY: + await self._match(newpath, attrs, newpatlist) + else: + self._report_match(newpath, attrs) + + async def _match(self, path: bytes, attrs: SFTPAttrs, + patlist: _SFTPPatList) -> None: + """Recursively match against a glob pattern""" + + pattern = patlist[0] + + if isinstance(pattern, list): + await self._match_exact(path, pattern, patlist) + else: + await self._match_pattern(path, attrs, pattern, patlist) + + async def match(self, pattern: bytes, + error_handler: SFTPErrorHandler = None, + sftp_version = MIN_SFTP_VERSION) -> Sequence[SFTPName]: + """Match against a glob pattern""" + + self._new_matches = [] + self._matched = False + + path, patlist = self._split(pattern) + + try: + attrs = await self._stat(path or b'.') + + if attrs: + if patlist: + if attrs.type == FILEXFER_TYPE_DIRECTORY: + await self._match(path, attrs, patlist) + elif path: + self._report_match(path, attrs) + + if pattern and not self._matched: + exc = SFTPNoSuchPath if sftp_version >= 4 else SFTPNoSuchFile + raise exc('No matches found') + except (OSError, SFTPError) as exc: + setattr(exc, 'srcpath', pattern) + + if error_handler: + error_handler(exc) + else: + raise + + return self._new_matches + + class SFTPHandler(SSHPacketLogger): """SFTP session handler""" @@ -2143,7 +2446,9 @@ class SFTPHandler(SSHPacketLogger): FXP_STAT: FXP_ATTRS, FXP_READLINK: FXP_NAME, b'statvfs@openssh.com': FXP_EXTENDED_REPLY, - b'fstatvfs@openssh.com': FXP_EXTENDED_REPLY + b'fstatvfs@openssh.com': FXP_EXTENDED_REPLY, + b'limits@openssh.com': FXP_EXTENDED_REPLY, + b'ranges@asyncssh.com': FXP_EXTENDED_REPLY } def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]'): @@ -2151,6 +2456,8 @@ def __init__(self, reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]'): self._writer: Optional['SSHWriter[bytes]'] = writer self._logger = reader.logger.get_child('sftp') + self.limits = SFTPLimits(0, SAFE_SFTP_READ_LEN, SAFE_SFTP_WRITE_LEN, 0) + @property def logger(self) -> SSHLogger: """A logger associated with this SFTP handler""" @@ -2303,6 +2610,10 @@ def __init__(self, loop: asyncio.AbstractEventLoop, self._supports_fstatvfs = False self._supports_hardlink = False self._supports_fsync = False + self._supports_lsetstat = False + self._supports_limits = False + self._supports_copy_data = False + self._supports_ranges = False @property def version(self) -> int: @@ -2310,6 +2621,12 @@ def version(self) -> int: return self._version + @property + def supports_copy_data(self) -> bool: + """Return whether or not SFTP remote copy is supported""" + + return self._supports_copy_data + async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP client session""" @@ -2365,7 +2682,7 @@ async def _make_request(self, pkttype: Union[int, bytes], return_type = self._return_types.get(pkttype) if resptype not in (FXP_STATUS, return_type): - raise SFTPBadMessage('Unexpected response type: %s' % resptype) + raise SFTPBadMessage(f'Unexpected response type: {resptype}') result = self._packet_handlers[resptype](self, resp) @@ -2425,7 +2742,7 @@ def _process_name(self, packet: SSHPacket) -> _SFTPNames: if self._version < 6: packet.check_end() - self.logger.debug1('Received %s%s', plural(len(names), 'name'), + self.logger.debug1('Received %s%s', plural(count, 'name'), ' (at end)' if at_end else '') for name in names: @@ -2484,7 +2801,7 @@ async def start(self) -> None: version = resp.get_uint32() if not MIN_SFTP_VERSION <= version <= MAX_SFTP_VERSION: - raise SFTPBadMessage('Unsupported version: %d' % version) + raise SFTPBadMessage(f'Unsupported version: {version}') rcvd_extensions: List[Tuple[bytes, bytes]] = [] @@ -2494,7 +2811,7 @@ async def start(self) -> None: rcvd_extensions.append((name, data)) except PacketDecodeError as exc: raise SFTPBadMessage(str(exc)) from None - except SFTPError as exc: + except SFTPError: raise except ConnectionLost as exc: raise SFTPConnectionLost(str(exc)) from None @@ -2519,6 +2836,14 @@ async def start(self) -> None: self._supports_hardlink = True elif name == b'fsync@openssh.com' and data == b'1': self._supports_fsync = True + elif name == b'lsetstat@openssh.com' and data == b'1': + self._supports_lsetstat = True + elif name == b'limits@openssh.com' and data == b'1': + self._supports_limits = True + elif name == b'copy-data' and data == b'1': + self._supports_copy_data = True + elif name == b'ranges@asyncssh.com' and data == b'1': + self._supports_ranges = True if version == 3: # Check if the server has a buggy SYMLINK implementation @@ -2532,6 +2857,24 @@ async def start(self) -> None: 'implementation') self._nonstandard_symlink = True + async def request_limits(self) -> None: + """Request SFTP server limits""" + + if self._supports_limits: + packet = cast(SSHPacket, await self._make_request( + b'limits@openssh.com')) + + limits = SFTPLimits.decode(packet) + packet.check_end() + + limits.log(self.logger, 'Received') + + if limits.max_read_len: + self.limits.max_read_len = limits.max_read_len + + if limits.max_write_len: + self.limits.max_write_len = limits.max_write_len + async def open(self, filename: bytes, pflags: int, attrs: SFTPAttrs) -> bytes: """Make an SFTP open request""" @@ -2596,27 +2939,34 @@ async def write(self, handle: bytes, offset: int, data: bytes) -> int: return cast(int, await self._make_request( FXP_WRITE, String(handle), UInt64(offset), String(data))) - async def stat(self, path: bytes, flags: int) -> SFTPAttrs: - """Make an SFTP stat request""" + async def stat(self, path: bytes, flags: int, *, + follow_symlinks: bool = True) -> SFTPAttrs: + """Make an SFTP stat or lstat request""" if self._version >= 4: flag_bytes = UInt32(flags) - flag_text = ', flags 0x%08x' % flags + flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' - self.logger.debug1('Sending stat for %s%s', path, flag_text) + if follow_symlinks: + self.logger.debug1('Sending stat for %s%s', path, flag_text) + + return cast(SFTPAttrs, await self._make_request( + FXP_STAT, String(path), flag_bytes)) + else: + self.logger.debug1('Sending lstat for %s%s', path, flag_text) - return cast(SFTPAttrs, await self._make_request( - FXP_STAT, String(path), flag_bytes)) + return cast(SFTPAttrs, await self._make_request( + FXP_LSTAT, String(path), flag_bytes)) async def lstat(self, path: bytes, flags: int) -> SFTPAttrs: """Make an SFTP lstat request""" if self._version >= 4: flag_bytes = UInt32(flags) - flag_text = ', flags 0x%08x' % flags + flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' @@ -2631,7 +2981,7 @@ async def fstat(self, handle: bytes, flags: int) -> SFTPAttrs: if self._version >= 4: flag_bytes = UInt32(flags) - flag_text = ', flags 0x%08x' % flags + flag_text = f', flags 0x{flags:08x}' else: flag_bytes = b'' flag_text = '' @@ -2642,13 +2992,24 @@ async def fstat(self, handle: bytes, flags: int) -> SFTPAttrs: return cast(SFTPAttrs, await self._make_request( FXP_FSTAT, String(handle), flag_bytes)) - async def setstat(self, path: bytes, attrs: SFTPAttrs) -> None: - """Make an SFTP setstat request""" + async def setstat(self, path: bytes, attrs: SFTPAttrs, *, + follow_symlinks: bool = True) -> None: + """Make an SFTP setstat or lsetstat request""" - self.logger.debug1('Sending setstat for %s%s', path, hide_empty(attrs)) + if follow_symlinks: + self.logger.debug1('Sending setstat for %s%s', + path, hide_empty(attrs)) - await self._make_request(FXP_SETSTAT, String(path), - attrs.encode(self._version)) + await self._make_request(FXP_SETSTAT, String(path), + attrs.encode(self._version)) + elif self._supports_lsetstat: + self.logger.debug1('Sending lsetstat for %s%s', + path, hide_empty(attrs)) + + await self._make_request(b'lsetstat@openssh.com', String(path), + attrs.encode(self._version)) + else: + raise SFTPOpUnsupported('lsetstat not supported by server') async def fsetstat(self, handle: bytes, attrs: SFTPAttrs) -> None: """Make an SFTP fsetstat request""" @@ -2707,7 +3068,7 @@ async def rename(self, oldpath: bytes, newpath: bytes, flags: int) -> None: if self._version >= 5: self.logger.debug1('Sending rename request from %s to %s%s', - oldpath, newpath, ', flags=0x%x' % flags + oldpath, newpath, f', flags=0x{flags:x}' if flags else '') await self._make_request(FXP_RENAME, String(oldpath), @@ -2784,12 +3145,12 @@ async def realpath(self, path: bytes, *compose_paths: bytes, checkmsg = '' else: try: - checkmsg = ', check=%s' % self._realpath_check_names[check] + checkmsg = f', check={self._realpath_check_names[check]}' except KeyError: - checkmsg = ', check=%d' % check + checkmsg = f', check={check}' self.logger.debug1('Sending realpath of %s%s%s', path, - b', compose_path: %s' % b', '.join(compose_paths) + b', compose_path: ' + b', '.join(compose_paths) if compose_paths else b'', checkmsg) if self._version >= 6: @@ -2880,6 +3241,48 @@ async def fsync(self, handle: bytes) -> None: else: raise SFTPOpUnsupported('fsync not supported') + async def copy_data(self, read_from_handle: bytes, read_from_offset: int, + read_from_length: int, write_to_handle: bytes, + write_to_offset: int) -> None: + """Make an SFTP copy data request""" + + if self._supports_copy_data: + self.logger.debug1('Sending copy-data from handle %s, ' + 'offset %d, length %d to handle %s, ' + 'offset %d', read_from_handle.hex(), + read_from_offset, read_from_length, + write_to_handle.hex(), write_to_offset) + + await self._make_request(b'copy-data', String(read_from_handle), + UInt64(read_from_offset), + UInt64(read_from_length), + String(write_to_handle), + UInt64(write_to_offset)) + else: + raise SFTPOpUnsupported('copy-data not supported') + + async def request_ranges(self, handle: bytes, offset: int, + length: int) -> SFTPRanges: + """Request file ranges containing data in a remote file""" + + if self._supports_ranges: + self.logger.debug1('Sending ranges request for handle %s, ' + 'offset %d, length %d', handle.hex(), + offset, length) + + packet = cast(SSHPacket, await self._make_request( + b'ranges@asyncssh.com', String(handle), + UInt64(offset), UInt64(length))) + + result = SFTPRanges.decode(packet) + packet.check_end() + + result.log(self.logger, 'Received') + + return result + else: + return SFTPRanges([(offset, length)], True) + def exit(self) -> None: """Handle a request to close the SFTP session""" @@ -2911,11 +3314,23 @@ def __init__(self, handler: SFTPClientHandler, handle: bytes, self._appending = appending self._encoding = encoding self._errors = errors - self._block_size = block_size - self._max_requests = max_requests self._offset = None if appending else 0 - async def __aenter__(self) -> 'SFTPClientFile': + self.read_len = \ + handler.limits.max_read_len if block_size == -1 else block_size + self.write_len = \ + handler.limits.max_write_len if block_size == -1 else block_size + + if max_requests <= 0: + if self.read_len: + max_requests = max(16, min(MAX_SFTP_READ_LEN // + self.read_len, 128)) + else: + max_requests = 1 + + self._max_requests = max_requests + + async def __aenter__(self) -> Self: """Allow SFTPClientFile to be used as an async context manager""" return self @@ -2928,30 +3343,144 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]], await self.close() return False + @property + def handle(self) -> bytes: + """Return handle or raise an error if clsoed""" + + if self._handle is None: + raise ValueError('I/O operation on closed file') + + return self._handle + async def _end(self) -> int: """Return the offset of the end of the file""" attrs = await self.stat() return attrs.size or 0 - async def read(self, size: int = -1, - offset: Optional[int] = None) -> AnyStr: - """Read data from the remote file + async def request_ranges(self, offset: int, length: int) -> \ + AsyncIterator[Tuple[int, int]]: + """Return file ranges containing data in a remote file""" + + next_offset = offset + next_length = length + end = offset + length + at_end = False + + try: + while not at_end: + result = await self._handler.request_ranges( + self.handle, next_offset, next_length) + + if result.ranges: + # pylint: disable=undefined-loop-variable + + for range_offset, range_length in result.ranges: + yield range_offset, range_length + + next_offset = range_offset + range_length + next_length = end - next_offset + else: # pragma: no cover + break + + at_end = result.at_end + except SFTPEOFError: + pass + + async def read(self, size: int = -1, + offset: Optional[int] = None) -> AnyStr: + """Read data from the remote file + + This method reads and returns up to `size` bytes of data + from the remote file. If size is negative, all data up to + the end of the file is returned. + + If offset is specified, the read will be performed starting + at that offset rather than the current file position. This + argument should be provided if you want to issue parallel + reads on the same file, since the file position is not + predictable in that case. + + Data will be returned as a string if an encoding was set when + the file was opened. Otherwise, data is returned as bytes. + + An empty `str` or `bytes` object is returned when at EOF. + + :param size: + The number of bytes to read + :param offset: (optional) + The offset from the beginning of the file to begin reading + :type size: `int` + :type offset: `int` + + :returns: data read from the file, as a `str` or `bytes` + + :raises: | :exc:`ValueError` if the file has been closed + | :exc:`UnicodeDecodeError` if the data can't be + decoded using the requested encoding + | :exc:`SFTPError` if the server returns an error + + """ + + if self._handle is None: + raise ValueError('I/O operation on closed file') + + if offset is None: + offset = self._offset + + # If self._offset is None, we're appending and haven't sought + # backward in the file since the last write, so there's no + # data to return + + data = b'' + + if offset is not None: + if size is None or size < 0: + size = (await self._end()) - offset + + try: + if self.read_len and size > \ + min(self.read_len, self._handler.limits.max_read_len): + data = await _SFTPFileReader( + self.read_len, self._max_requests, self._handler, + self._handle, offset, size).run() + else: + data, _ = await self._handler.read(self._handle, + offset, size) + + self._offset = offset + len(data) + except SFTPEOFError: + pass + + if self._encoding: + return cast(AnyStr, data.decode(self._encoding, self._errors)) + else: + return cast(AnyStr, data) + + async def read_parallel(self, size: int = -1, + offset: Optional[int] = None) -> \ + AsyncIterator[Tuple[int, bytes]]: + """Read parallel blocks of data from the remote file This method reads and returns up to `size` bytes of data from the remote file. If size is negative, all data up to the end of the file is returned. If offset is specified, the read will be performed starting - at that offset rather than the current file position. This - argument should be provided if you want to issue parallel - reads on the same file, since the file position is not - predictable in that case. + at that offset rather than the current file position. - Data will be returned as a string if an encoding was set when - the file was opened. Otherwise, data is returned as bytes. + Data is returned as a series of tuples delivered by an + async iterator, where each tuple contains an offset and + data bytes. Encoding is ignored here, since multi-byte + characters may be split across block boundaries. - An empty `str` or `bytes` object is returned when at EOF. + To maximize performance, multiple reads are issued in + parallel, and data blocks may be returned out of order. + The size of the blocks and the maximum number of + outstanding read requests can be controlled using + the `block_size` and `max_requests` arguments passed + in the call to the :meth:`open() ` + method on the :class:`SFTPClient` class. :param size: The number of bytes to read @@ -2960,11 +3489,9 @@ async def read(self, size: int = -1, :type size: `int` :type offset: `int` - :returns: data read from the file, as a `str` or `bytes` + :returns: an async iterator of tuples of offset and data bytes :raises: | :exc:`ValueError` if the file has been closed - | :exc:`UnicodeDecodeError` if the data can't be - decoded using the requested encoding | :exc:`SFTPError` if the server returns an error """ @@ -2975,33 +3502,20 @@ async def read(self, size: int = -1, if offset is None: offset = self._offset - # If self._offset is None, we're appending and haven't seeked + # If self._offset is None, we're appending and haven't sought # backward in the file since the last write, so there's no # data to return - data = b'' - if offset is not None: if size is None or size < 0: size = (await self._end()) - offset - - try: - if self._block_size and size > self._block_size: - data = await _SFTPFileReader( - self._block_size, self._max_requests, self._handler, - self._handle, offset, size).run() - else: - data, _ = await self._handler.read(self._handle, - offset, size) - - self._offset = offset + len(data) - except SFTPEOFError: - pass - - if self._encoding: - return cast(AnyStr, data.decode(self._encoding, self._errors)) else: - return cast(AnyStr, data) + offset = 0 + size = 0 + + return _SFTPFileReader(self.read_len, self._max_requests, + self._handler, self._handle, offset, + size).iter() async def write(self, data: AnyStr, offset: Optional[int] = None) -> int: """Write data to the remote file @@ -3046,9 +3560,9 @@ async def write(self, data: AnyStr, offset: Optional[int] = None) -> int: datalen = len(data_bytes) - if self._block_size and datalen > self._block_size: + if self.write_len and datalen > self.write_len: await _SFTPFileWriter( - self._block_size, self._max_requests, self._handler, + self.write_len, self._max_requests, self._handler, self._handle, offset, data_bytes).run() else: await self._handler.write(self._handle, offset, data_bytes) @@ -3319,7 +3833,7 @@ def __init__(self, handler: SFTPClientHandler, self._path_errors = path_errors self._cwd: Optional[bytes] = None - async def __aenter__(self) -> 'SFTPClient': + async def __aenter__(self) -> Self: """Allow SFTPClient to be used as an async context manager""" return self @@ -3345,6 +3859,18 @@ def version(self) -> int: return self._handler.version + @property + def limits(self) -> SFTPLimits: + """:class:`SFTPLimits` associated with this SFTP session""" + + return self._handler.limits + + @property + def supports_remote_copy(self) -> bool: + """Return whether or not SFTP remote copy is supported""" + + return self._handler.supports_copy_data + @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a POSIX-style path""" @@ -3415,27 +3941,28 @@ async def _type(self, path: _SFTPPath, return FILEXFER_TYPE_UNKNOWN async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, - srcpath: bytes, dstpath: bytes, preserve: bool, - recurse: bool, follow_symlinks: bool, block_size: int, - max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + srcpath: bytes, dstpath: bytes, srcattrs: SFTPAttrs, + preserve: bool, recurse: bool, follow_symlinks: bool, + sparse: bool, block_size: int, max_requests: int, + progress_handler: SFTPProgressHandler, + error_handler: SFTPErrorHandler, + remote_only: bool) -> None: """Copy a file, directory, or symbolic link""" try: - if follow_symlinks: - srcattrs = await srcfs.stat(srcpath) - else: - srcattrs = await srcfs.lstat(srcpath) - filetype = srcattrs.type + if follow_symlinks and filetype == FILEXFER_TYPE_SYMLINK: + srcattrs = await srcfs.stat(srcpath) + filetype = srcattrs.type + if filetype == FILEXFER_TYPE_DIRECTORY: if not recurse: exc = SFTPFileIsADirectory if self.version >= 6 \ else SFTPFailure - raise exc('%s is a directory' % - srcpath.decode('utf-8', 'backslashreplace')) + raise exc(srcpath.decode('utf-8', 'backslashreplace') + + ' is a directory') self.logger.info(' Starting copy of directory %s to %s', srcpath, dstpath) @@ -3443,19 +3970,20 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, if not await dstfs.isdir(dstpath): await dstfs.mkdir(dstpath) - names = await srcfs.listdir(srcpath) + async for srcname in srcfs.scandir(srcpath): + filename = cast(bytes, srcname.filename) - for name in names: - if name in (b'.', b'..'): + if filename in (b'.', b'..'): continue - srcfile = posixpath.join(srcpath, name) - dstfile = posixpath.join(dstpath, name) + srcfile = posixpath.join(srcpath, filename) + dstfile = posixpath.join(dstpath, filename) await self._copy(srcfs, dstfs, srcfile, dstfile, - preserve, recurse, follow_symlinks, - block_size, max_requests, - progress_handler, error_handler) + srcname.attrs, preserve, recurse, + follow_symlinks, sparse, block_size, + max_requests, progress_handler, + error_handler, remote_only) self.logger.info(' Finished copy of directory %s to %s', srcpath, dstpath) @@ -3470,20 +3998,31 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, else: self.logger.info(' Copying file %s to %s', srcpath, dstpath) - await _SFTPFileCopier(block_size, max_requests, 0, - srcattrs.size or 0, srcfs, dstfs, - srcpath, dstpath, progress_handler).run() + if remote_only and not self.supports_remote_copy: + raise SFTPOpUnsupported('Remote copy not supported') + + await _SFTPFileCopier(block_size, max_requests, + srcattrs.size or 0, sparse, + srcfs, dstfs, srcpath, dstpath, + progress_handler).run() if preserve: - attrs = await srcfs.stat(srcpath) + attrs = await srcfs.stat(srcpath, + follow_symlinks=follow_symlinks) attrs = SFTPAttrs(permissions=attrs.permissions, atime=attrs.atime, atime_ns=attrs.atime_ns, mtime=attrs.mtime, mtime_ns=attrs.mtime_ns) - self.logger.info(' Preserving attrs: %s', attrs) + try: + await dstfs.setstat(dstpath, attrs, + follow_symlinks=follow_symlinks or + filetype != FILEXFER_TYPE_SYMLINK) + + self.logger.info(' Preserved attrs: %s', attrs) + except SFTPOpUnsupported: + self.logger.info(' Preserving symlink attrs unsupported') - await dstfs.setstat(dstpath, attrs) except (OSError, SFTPError) as exc: setattr(exc, 'srcpath', srcpath) setattr(exc, 'dstpath', dstpath) @@ -3494,37 +4033,44 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, raise async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, - srcpaths: Sequence[_SFTPPath], - dstpath: Optional[_SFTPPath], + srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath], copy_type: str, expand_glob: bool, preserve: bool, - recurse: bool, follow_symlinks: bool, + recurse: bool, follow_symlinks: bool, sparse: bool, block_size: int, max_requests: int, progress_handler: SFTPProgressHandler, - error_handler: SFTPErrorHandler) -> None: + error_handler: SFTPErrorHandler, + remote_only: bool = False) -> None: """Begin a new file upload, download, or copy""" - if isinstance(srcpaths, tuple): - srcpaths = list(srcpaths) + if block_size <= 0: + block_size = min(srcfs.limits.max_read_len, + dstfs.limits.max_write_len) - self.logger.info('Starting SFTP %s of %s to %s', - copy_type, srcpaths, dstpath) + if max_requests <= 0: + max_requests = max(16, min(MAX_SFTP_READ_LEN // block_size, 128)) if isinstance(srcpaths, (bytes, str, PurePath)): srcpaths = [srcpaths] + elif not isinstance(srcpaths, list): + srcpaths = list(srcpaths) + + self.logger.info('Starting SFTP %s of %s to %s', + copy_type, srcpaths, dstpath) - exppaths: List[bytes] + srcnames: List[SFTPName] = [] if expand_glob: - exppaths = [] + glob = SFTPGlob(srcfs, len(srcpaths) > 1) - for pattern in srcpaths: - if not pattern: - continue - - exppaths.extend(await match_glob(srcfs, srcfs.encode(pattern), + for srcpath in srcpaths: + srcnames.extend(await glob.match(srcfs.encode(srcpath), error_handler, self.version)) else: - exppaths = [srcfs.encode(srcfile) for srcfile in srcpaths] + for srcpath in srcpaths: + srcpath = srcfs.encode(srcpath) + srcattrs = await srcfs.stat(srcpath, + follow_symlinks=follow_symlinks) + srcnames.append(SFTPName(srcpath, attrs=srcattrs)) if dstpath: dstpath = dstfs.encode(dstpath) @@ -3533,33 +4079,34 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, dst_isdir = dstpath is None or (await dstfs.isdir(dstpath)) - if len(exppaths) > 1 and not dst_isdir: + if len(srcnames) > 1 and not dst_isdir: assert dstpath is not None exc = SFTPNotADirectory if self.version >= 6 else SFTPFailure - raise exc('%s must be a directory' % - dstpath.decode('utf-8', 'backslashreplace')) + raise exc(dstpath.decode('utf-8', 'backslashreplace') + + ' must be a directory') - for srcfile in exppaths: - filename = srcfs.basename(srcfile) + for srcname in srcnames: + srcfile = cast(bytes, srcname.filename) + basename = srcfs.basename(srcfile) if dstpath is None: - dstfile = filename + dstfile = basename elif dst_isdir: - dstfile = dstfs.compose_path(filename, parent=dstpath) + dstfile = dstfs.compose_path(basename, parent=dstpath) else: dstfile = dstpath - await self._copy(srcfs, dstfs, srcfile, dstfile, preserve, - recurse, follow_symlinks, block_size, - max_requests, progress_handler, error_handler) + await self._copy(srcfs, dstfs, srcfile, dstfile, srcname.attrs, + preserve, recurse, follow_symlinks, sparse, + block_size, max_requests, progress_handler, + error_handler, remote_only) - async def get(self, remotepaths: Sequence[_SFTPPath], + async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files @@ -3595,10 +4142,14 @@ async def get(self, remotepaths: Sequence[_SFTPPath], watch out for links that result in loops. The block_size argument specifies the size of read and write - requests issued when downloading the files, defaulting to 16 KB. + requests issued when downloading the files, defaulting to + the maximum allowed by the server, or 16 KB if the server + doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully downloaded. The arguments @@ -3628,6 +4179,8 @@ async def get(self, remotepaths: Sequence[_SFTPPath], Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links + :param sparse: (optional) + Whether or not to do a sparse file copy where it is supported :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) @@ -3644,6 +4197,7 @@ async def get(self, remotepaths: Sequence[_SFTPPath], :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` + :type sparse: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` @@ -3656,15 +4210,14 @@ async def get(self, remotepaths: Sequence[_SFTPPath], await self._begin_copy(self, local_fs, remotepaths, localpath, 'get', False, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler) - async def put(self, localpaths: Sequence[_SFTPPath], + async def put(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files @@ -3700,10 +4253,14 @@ async def put(self, localpaths: Sequence[_SFTPPath], watch out for links that result in loops. The block_size argument specifies the size of read and write - requests issued when uploading the files, defaulting to 16 KB. + requests issued when uploading the files, defaulting to + the maximum allowed by the server, or 16 KB if the server + doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully uploaded. The arguments @@ -3733,6 +4290,8 @@ async def put(self, localpaths: Sequence[_SFTPPath], Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links + :param sparse: (optional) + Whether or not to do a sparse file copy where it is supported :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) @@ -3749,6 +4308,7 @@ async def put(self, localpaths: Sequence[_SFTPPath], :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` + :type sparse: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` @@ -3761,17 +4321,17 @@ async def put(self, localpaths: Sequence[_SFTPPath], await self._begin_copy(local_fs, self, localpaths, remotepath, 'put', False, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler) - async def copy(self, srcpaths: Sequence[_SFTPPath], + async def copy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int =SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, - error_handler: SFTPErrorHandler = None) -> None: + error_handler: SFTPErrorHandler = None, + remote_only: bool = False) -> None: """Copy remote files to a new location This method copies one or more files or directories on the @@ -3805,10 +4365,14 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], watch out for links that result in loops. The block_size argument specifies the size of read and write - requests issued when copying the files, defaulting to 16 KB. + requests issued when copying the files, defaulting to the + maximum allowed by the server, or 16 KB if the server + doesn't advertise limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. If progress_handler is specified, it will be called after each block of a file is successfully copied. The arguments @@ -3838,6 +4402,8 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], Whether or not to recursively copy directories :param follow_symlinks: (optional) Whether or not to follow symbolic links + :param sparse: (optional) + Whether or not to do a sparse file copy where it is supported :param block_size: (optional) The block size to use for file reads and writes :param max_requests: (optional) @@ -3846,6 +4412,8 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], The function to call to report copy progress :param error_handler: (optional) The function to call when an error occurs + :param remote_only: (optional) + Whether or not to only allow this to be a remote copy :type srcpaths: :class:`PurePath `, `str`, or `bytes`, or a sequence of these @@ -3854,10 +4422,12 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], :type preserve: `bool` :type recurse: `bool` :type follow_symlinks: `bool` + :type sparse: `bool` :type block_size: `int` :type max_requests: `int` :type progress_handler: `callable` :type error_handler: `callable` + :type remote_only: `bool` :raises: | :exc:`OSError` if a local file I/O error occurs | :exc:`SFTPError` if the server returns an error @@ -3866,15 +4436,14 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], await self._begin_copy(self, self, srcpaths, dstpath, 'remote copy', False, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler, remote_only) - async def mget(self, remotepaths: Sequence[_SFTPPath], + async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Download remote files with glob pattern match @@ -3890,15 +4459,14 @@ async def mget(self, remotepaths: Sequence[_SFTPPath], await self._begin_copy(self, local_fs, remotepaths, localpath, 'mget', True, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler) - async def mput(self, localpaths: Sequence[_SFTPPath], + async def mput(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, error_handler: SFTPErrorHandler = None) -> None: """Upload local files with glob pattern match @@ -3914,18 +4482,18 @@ async def mput(self, localpaths: Sequence[_SFTPPath], await self._begin_copy(local_fs, self, localpaths, remotepath, 'mput', True, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler) - async def mcopy(self, srcpaths: Sequence[_SFTPPath], + async def mcopy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, - follow_symlinks: bool = False, - block_size: int =SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS, + follow_symlinks: bool = False, sparse: bool = True, + block_size: int = -1, max_requests: int = -1, progress_handler: SFTPProgressHandler = None, - error_handler: SFTPErrorHandler = None) -> None: - """Download remote files with glob pattern match + error_handler: SFTPErrorHandler = None, + remote_only: bool = False) -> None: + """Copy remote files with glob pattern match This method copies files and directories on the remote system matching one or more glob patterns. @@ -3938,10 +4506,49 @@ async def mcopy(self, srcpaths: Sequence[_SFTPPath], await self._begin_copy(self, self, srcpaths, dstpath, 'remote mcopy', True, preserve, recurse, follow_symlinks, - block_size, max_requests, progress_handler, - error_handler) + sparse, block_size, max_requests, + progress_handler, error_handler, remote_only) + + async def remote_copy(self, src: _SFTPClientFileOrPath, + dst: _SFTPClientFileOrPath, src_offset: int = 0, + src_length: int = 0, dst_offset: int = 0) -> None: + """Copy data between remote files + + :param src: + The remote file object to read data from + :param dst: + The remote file object to write data to + :param src_offset: (optional) + The offset to begin reading data from + :param src_length: (optional) + The number of bytes to attempt to copy + :param dst_offset: (optional) + The offset to begin writing data to + :type src: + :class:`SFTPClientFile`, :class:`PurePath `, + `str`, or `bytes` + :type dst: + :class:`SFTPClientFile`, :class:`PurePath `, + `str`, or `bytes` + :type src_offset: `int` + :type src_length: `int` + :type dst_offset: `int` + + :raises: :exc:`SFTPError` if the server doesn't support this + extension or returns an error + + """ + + if isinstance(src, (bytes, str, PurePath)): + src = await self.open(src, 'rb', block_size=0) + + if isinstance(dst, (bytes, str, PurePath)): + dst = await self.open(dst, 'wb', block_size=0) - async def glob(self, patterns: Union[_SFTPPath, Sequence[_SFTPPath]], + await self._handler.copy_data(src.handle, src_offset, src_length, + dst.handle, dst_offset) + + async def glob(self, patterns: _SFTPPaths, error_handler: SFTPErrorHandler = None) -> \ Sequence[BytesOrStr]: """Match remote files against glob patterns @@ -3985,27 +4592,36 @@ async def glob(self, patterns: Union[_SFTPPath, Sequence[_SFTPPath]], """ + return [name.filename for name in + await self.glob_sftpname(patterns, error_handler)] + + async def glob_sftpname(self, patterns: _SFTPPaths, + error_handler: SFTPErrorHandler = None) -> \ + Sequence[SFTPName]: + """Match glob patterns and return SFTPNames + + This method is similar to :meth:`glob`, but it returns matching + file names and attributes as :class:`SFTPName` objects. + + """ + if isinstance(patterns, (bytes, str, PurePath)): patterns = [patterns] - result: List[BytesOrStr] = [] + glob = SFTPGlob(self, len(patterns) > 1) + matches: List[SFTPName] = [] for pattern in patterns: - if not pattern: - continue - - enc_names = await match_glob(self, self.encode(pattern), - error_handler, self.version) + new_matches = await glob.match(self.encode(pattern), + error_handler, self.version) if isinstance(pattern, (str, PurePath)): - names = [self.decode(name) for name in enc_names] - else: - names = cast(List[BytesOrStr], enc_names) - - result.extend(names) + for name in new_matches: + name.filename = self.decode(cast(bytes, name.filename)) - return result + matches.extend(new_matches) + return matches async def makedirs(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), exist_ok: bool = False) -> None: @@ -4037,8 +4653,13 @@ async def makedirs(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), path = self.encode(path) curpath = b'/' if posixpath.isabs(path) else (self._cwd or b'') exists = True + parts = path.split(b'/') + last = len(parts) - 1 + + exc: Union[Type[SFTPNotADirectory], Type[SFTPFailure], + Type[SFTPFileAlreadyExists]] - for part in path.split(b'/'): + for i, part in enumerate(parts): curpath = posixpath.join(curpath, part) try: @@ -4053,13 +4674,16 @@ async def makedirs(self, path: _SFTPPath, attrs: SFTPAttrs = SFTPAttrs(), exc = SFTPNotADirectory if self.version >= 6 \ else SFTPFailure - raise exc('%s is not a directory' % curpath_str) from None + raise exc(f'{curpath_str} is not a directory') from None + except SFTPPermissionDenied: + if i == last: + raise if exists and not exist_ok: exc = SFTPFileAlreadyExists if self.version >= 6 else SFTPFailure - raise exc('%s already exists' % - curpath.decode('utf-8', 'backslashreplace')) + raise exc(curpath.decode('utf-8', 'backslashreplace') + + ' already exists') async def rmtree(self, path: _SFTPPath, ignore_errors: bool = False, onerror: _SFTPOnErrorHandler = None) -> None: @@ -4154,8 +4778,8 @@ def onerror(*_args: object) -> None: try: if await self.islink(path): - raise SFTPNoSuchFile('%s must not be a symlink' % - path.decode('utf-8', 'backslashreplace')) + raise SFTPNoSuchFile(path.decode('utf-8', 'backslashreplace') + + ' must not be a symlink') except SFTPError: onerror(self.islink, path, sys.exc_info()) return @@ -4167,8 +4791,8 @@ async def open(self, path: _SFTPPath, pflags_or_mode: Union[int, str] = FXF_READ, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS) -> SFTPClientFile: + block_size: int = -1, + max_requests: int = -1) -> SFTPClientFile: """Open a remote file This method opens a remote file and returns an @@ -4220,7 +4844,7 @@ async def open(self, path: _SFTPPath, Most applications should be able to use this method regardless of the version of the SFTP protocol negotiated with the SFTP server. A conversion from the pflags_or_mode values to the - SFTPv5/v6 flag values will happen automaitcally. However, if + SFTPv5/v6 flag values will happen automatically. However, if an application wishes to set flags only available in SFTPv5/v6, the :meth:`open56` method may be used to specify these flags explicitly. @@ -4234,17 +4858,19 @@ async def open(self, path: _SFTPPath, or write call will become a single request to the SFTP server. Otherwise, read or write calls larger than this size will be turned into parallel requests to the server of the requested - size, defaulting to 16 KB. + size, defaulting to the maximum allowed by the server, or 16 KB + if the server doesn't advertise limits. .. note:: The OpenSSH SFTP server will close the connection - if it receives a message larger than 256 KB, and - limits read requests to returning no more than - 64 KB. So, when connecting to an OpenSSH SFTP - server, it is recommended that the block_size be - set below these sizes. + if it receives a message larger than 256 KB. So, + when connecting to an OpenSSH SFTP server, it is + recommended that the block_size be left at its + default of using the server-advertised limits. The max_requests argument specifies the maximum number of - parallel read or write requests issued, defaulting to 128. + parallel read or write requests issued, defaulting to a + value between 16 and 128 depending on the selected block + size to avoid excessive memory usage. :param path: The name of the remote file to open @@ -4299,8 +4925,8 @@ async def open56(self, path: _SFTPPath, flags: int = FXF_OPEN_EXISTING, attrs: SFTPAttrs = SFTPAttrs(), encoding: Optional[str] = 'utf-8', errors: str = 'strict', - block_size: int = SFTP_BLOCK_SIZE, - max_requests: int = _MAX_SFTP_REQUESTS) -> SFTPClientFile: + block_size: int = -1, + max_requests: int = -1) -> SFTPClientFile: """Open a remote file using SFTP v5/v6 flags This method is very similar to :meth:`open`, but the pflags_or_mode @@ -4388,21 +5014,24 @@ async def open56(self, path: _SFTPPath, flags & FXF_APPEND_DATA), encoding, errors, block_size, max_requests) - async def stat(self, path: _SFTPPath, - flags = FILEXFER_ATTR_DEFINED_V4) -> SFTPAttrs: - """Get attributes of a remote file or directory, following symlinks + async def stat(self, path: _SFTPPath, flags = FILEXFER_ATTR_DEFINED_V4, *, + follow_symlinks: bool = True) -> SFTPAttrs: + """Get attributes of a remote file, directory, or symlink - This method queries the attributes of a remote file or - directory. If the path provided is a symbolic link, the - returned attributes will correspond to the target of the - link. + This method queries the attributes of a remote file, directory, + or symlink. If the path provided is a symlink and follow_symlinks + is `True`, the returned attributes will correspond to the target + of the link. :param path: The path of the remote file or directory to get attributes for :param flags: (optional) Flags indicating attributes of interest (SFTPv4 only) + :param follow_symlinks: (optional) + Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type flags: `int` + :type follow_symlinks: `bool` :returns: An :class:`SFTPAttrs` containing the file attributes @@ -4411,7 +5040,8 @@ async def stat(self, path: _SFTPPath, """ path = self.compose_path(path) - return await self._handler.stat(path, flags) + return await self._handler.stat(path, flags, + follow_symlinks=follow_symlinks) async def lstat(self, path: _SFTPPath, flags = FILEXFER_ATTR_DEFINED_V4) -> SFTPAttrs: @@ -4439,14 +5069,15 @@ async def lstat(self, path: _SFTPPath, path = self.compose_path(path) return await self._handler.lstat(path, flags) - async def setstat(self, path: _SFTPPath, attrs: SFTPAttrs) -> None: - """Set attributes of a remote file or directory + async def setstat(self, path: _SFTPPath, attrs: SFTPAttrs, *, + follow_symlinks: bool = True) -> None: + """Set attributes of a remote file, directory, or symlink - This method sets attributes of a remote file or directory. - If the path provided is a symbolic link, the attributes - will be set on the target of the link. A subset of the - fields in `attrs` can be initialized and only those - attributes will be changed. + This method sets attributes of a remote file, directory, or + symlink. If the path provided is a symlink and follow_symlinks + is `True`, the attributes will be set on the target of the link. + A subset of the fields in `attrs` can be initialized and only + those attributes will be changed. :param path: The path of the remote file or directory to set attributes for @@ -4460,7 +5091,9 @@ async def setstat(self, path: _SFTPPath, attrs: SFTPAttrs) -> None: """ path = self.compose_path(path) - await self._handler.setstat(path, attrs) + + await self._handler.setstat(path, attrs, + follow_symlinks=follow_symlinks) async def statvfs(self, path: _SFTPPath) -> SFTPVFSAttrs: """Get attributes of a remote file system @@ -4504,20 +5137,23 @@ async def truncate(self, path: _SFTPPath, size: int) -> None: await self.setstat(path, SFTPAttrs(size=size)) @overload - async def chown(self, path: _SFTPPath, - uid: int, gid: int) -> None: ... # pragma: no cover + async def chown(self, path: _SFTPPath, uid: int, gid: int, *, + follow_symlinks: bool = True) -> \ + None: ... # pragma: no cover @overload - async def chown(self, path: _SFTPPath, - owner: str, group: str) -> None: ... # pragma: no cover + async def chown(self, path: _SFTPPath, owner: str, group: str, *, + follow_symlinks: bool = True) -> \ + None: ... # pragma: no cover async def chown(self, path, uid_or_owner = None, gid_or_group = None, - uid = None, gid = None, owner = None, group = None): - """Change the owner user and group id of a remote file or directory + uid = None, gid = None, owner = None, group = None, *, + follow_symlinks = True): + """Change the owner of a remote file, directory, or symlink - This method changes the user and group id of a remote - file or directory. If the path provided is a symbolic - link, the target of the link will be changed. + This method changes the user and group id of a remote file, + directory, or symlink. If the path provided is a symlink and + follow_symlinks is `True`, the target of the link will be changed. :param path: The path of the remote file to change @@ -4529,11 +5165,14 @@ async def chown(self, path, uid_or_owner = None, gid_or_group = None, The new owner to assign to the file (SFTPv4 only) :param group: The new group to assign to the file (SFTPv4 only) + :param follow_symlinks: (optional) + Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type uid: `int` :type gid: `int` :type owner: `str` :type group: `str` + :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error @@ -4550,39 +5189,46 @@ async def chown(self, path, uid_or_owner = None, gid_or_group = None, group = gid_or_group await self.setstat(path, SFTPAttrs(uid=uid, gid=gid, - owner=owner, group=group)) + owner=owner, group=group), + follow_symlinks=follow_symlinks) - async def chmod(self, path: _SFTPPath, mode: int) -> None: - """Change the file permissions of a remote file or directory + async def chmod(self, path: _SFTPPath, mode: int, *, + follow_symlinks: bool = True) -> None: + """Change the permissions of a remote file, directory, or symlink - This method changes the permissions of a remote file or - directory. If the path provided is a symbolic link, the - target of the link will be changed. + This method changes the permissions of a remote file, directory, + or symlink. If the path provided is a symlink and follow_symlinks + is `True`, the target of the link will be changed. :param path: The path of the remote file to change :param mode: The new file permissions, expressed as an int + :param follow_symlinks: (optional) + Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type mode: `int` + :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error """ - await self.setstat(path, SFTPAttrs(permissions=mode)) + await self.setstat(path, SFTPAttrs(permissions=mode), + follow_symlinks=follow_symlinks) async def utime(self, path: _SFTPPath, times: Optional[Tuple[float, float]] = None, - ns: Optional[Tuple[int, int]] = None) -> None: - """Change the access and modify times of a remote file or directory + ns: Optional[Tuple[int, int]] = None, *, + follow_symlinks: bool = True) -> None: + """Change the timestamps of a remote file, directory, or symlink - This method changes the access and modify times of a - remote file or directory. If neither `times` nor '`ns` is - provided, the times will be changed to the current time. + This method changes the access and modify times of a remote file, + directory, or symlink. If neither `times` nor '`ns` is provided, + the times will be changed to the current time. - If the path provided is a symbolic link, the target of the - link will be changed. + If the path provided is a symlink and follow_symlinks is `True`, + the target of the link will be changed. :param path: The path of the remote file to change @@ -4592,15 +5238,19 @@ async def utime(self, path: _SFTPPath, :param ns: (optional) The new access and modify times, as nanoseconds relative to the UNIX epoch + :param follow_symlinks: (optional) + Whether or not to follow symbolic links :type path: :class:`PurePath `, `str`, or `bytes` :type times: tuple of two `int` or `float` values :type ns: tuple of two `int` values + :type follow_symlinks: `bool` :raises: :exc:`SFTPError` if the server returns an error """ - await self.setstat(path, _utime_to_attrs(times, ns)) + await self.setstat(path, _utime_to_attrs(times, ns), + follow_symlinks=follow_symlinks) async def exists(self, path: _SFTPPath) -> bool: """Return if the remote path exists and isn't a broken symbolic link @@ -4613,7 +5263,7 @@ async def exists(self, path: _SFTPPath) -> bool: """ - return (await self._type(path)) != FILEXFER_TYPE_UNKNOWN + return await self._type(path) != FILEXFER_TYPE_UNKNOWN async def lexists(self, path: _SFTPPath) -> bool: """Return if the remote path exists, without following symbolic links @@ -4626,7 +5276,7 @@ async def lexists(self, path: _SFTPPath) -> bool: """ - return (await self._type(path, statfunc=self.lstat)) != \ + return await self._type(path, statfunc=self.lstat) != \ FILEXFER_TYPE_UNKNOWN async def getatime(self, path: _SFTPPath) -> Optional[float]: @@ -4755,7 +5405,7 @@ async def isdir(self, path: _SFTPPath) -> bool: """ - return (await self._type(path)) == FILEXFER_TYPE_DIRECTORY + return await self._type(path) == FILEXFER_TYPE_DIRECTORY async def isfile(self, path: _SFTPPath) -> bool: """Return if the remote path refers to a regular file @@ -4768,7 +5418,7 @@ async def isfile(self, path: _SFTPPath) -> bool: """ - return (await self._type(path)) == FILEXFER_TYPE_REGULAR + return await self._type(path) == FILEXFER_TYPE_REGULAR async def islink(self, path: _SFTPPath) -> bool: """Return if the remote path refers to a symbolic link @@ -4781,7 +5431,7 @@ async def islink(self, path: _SFTPPath) -> bool: """ - return (await self._type(path, statfunc=self.lstat)) == \ + return await self._type(path, statfunc=self.lstat) == \ FILEXFER_TYPE_SYMLINK async def remove(self, path: _SFTPPath) -> None: @@ -4814,7 +5464,7 @@ async def rename(self, oldpath: _SFTPPath, newpath: _SFTPPath, .. note:: By default, this version of rename will not overwrite the new path if it already exists. However, this can be controlled using the `flags` argument, - available in SFTPv5 and later. Whan a connection + available in SFTPv5 and later. When a connection is negotiated to use an earliler version of SFTP and `flags` is set, this method will attempt to fall back to the OpenSSH "posix-rename" extension @@ -4876,7 +5526,7 @@ async def posix_rename(self, oldpath: _SFTPPath, await self._handler.posix_rename(oldpath, newpath) async def scandir(self, path: _SFTPPath = '.') -> AsyncIterator[SFTPName]: - """Return an async iterator of the contents of a remote directory + """Return names and attributes of the files in a remote directory This method reads the contents of a directory, returning the names and attributes of what is contained there as an @@ -4943,7 +5593,7 @@ async def listdir(self, path: bytes) -> \ Sequence[bytes]: ... # pragma: no cover @overload - async def listdir(self, path: FilePath) -> \ + async def listdir(self, path: FilePath = ...) -> \ Sequence[str]: ... # pragma: no cover async def listdir(self, path: _SFTPPath = '.') -> Sequence[BytesOrStr]: @@ -5170,8 +5820,8 @@ async def symlink(self, oldpath: _SFTPPath, newpath: _SFTPPath) -> None: """ - oldpath = self.compose_path(oldpath) - newpath = self.encode(newpath) + oldpath = self.encode(oldpath) + newpath = self.compose_path(newpath) await self._handler.symlink(oldpath, newpath) async def link(self, oldpath: _SFTPPath, newpath: _SFTPPath) -> None: @@ -5219,7 +5869,7 @@ async def wait_closed(self) -> None: class SFTPServerHandler(SFTPHandler): """An SFTP server session handler""" - # Supported attribute flags in setstat/fsetstat + # Supported attribute flags in setstat/fsetstat/lsetstat _supported_attr_mask = FILEXFER_ATTR_SIZE | \ FILEXFER_ATTR_PERMISSIONS | \ FILEXFER_ATTR_ACCESSTIME | \ @@ -5233,7 +5883,7 @@ class SFTPServerHandler(SFTPHandler): # Supported SFTPv5/v6 open flags _supported_open_flags = FXF_ACCESS_DISPOSITION | FXF_APPEND_DATA - # Supported SFTPv5/v6 desired accesss flags + # Supported SFTPv5/v6 desired access flags _supported_access_mask = ACE4_READ_DATA | ACE4_WRITE_DATA | \ ACE4_APPEND_DATA | ACE4_READ_ATTRIBUTES | \ ACE4_WRITE_ATTRIBUTES @@ -5249,7 +5899,11 @@ class SFTPServerHandler(SFTPHandler): (b'vendor-id', _vendor_id), (b'posix-rename@openssh.com', b'1'), (b'hardlink@openssh.com', b'1'), - (b'fsync@openssh.com', b'1')] + (b'fsync@openssh.com', b'1'), + (b'lsetstat@openssh.com', b'1'), + (b'limits@openssh.com', b'1'), + (b'copy-data', b'1'), + (b'ranges@asyncssh.com', b'1')] _attrib_extensions: List[bytes] = [] @@ -5266,7 +5920,7 @@ def __init__(self, server: 'SFTPServer', reader: 'SSHReader[bytes]', self._nonstandard_symlink = False self._next_handle = 0 self._file_handles: Dict[bytes, object] = {} - self._dir_handles: Dict[bytes, List[SFTPName]] = {} + self._dir_handles: Dict[bytes, AsyncIterator[SFTPName]] = {} async def _cleanup(self, exc: Optional[Exception]) -> None: """Clean up this SFTP server session""" @@ -5276,10 +5930,12 @@ async def _cleanup(self, exc: Optional[Exception]) -> None: result = self._server.close(file_obj) if inspect.isawaitable(result): - assert result is not None await result - self._server.exit() + result = self._server.exit() + + if inspect.isawaitable(result): + await result self._file_handles = {} self._dir_handles = {} @@ -5312,8 +5968,7 @@ async def _process_packet(self, pkttype: int, pktid: int, handler = self._packet_handlers.get(handler_type) if not handler: - raise SFTPOpUnsupported('Unsupported request type: %s' % - pkttype) + raise SFTPOpUnsupported(f'Unsupported request type: {pkttype}') return_type = self._return_types.get(handler_type, FXP_STATUS) result = await handler(self, packet) @@ -5352,6 +6007,12 @@ async def _process_packet(self, pkttype: int, pktid: int, response = (UInt32(len(names)) + b''.join(name.encode(self._version) for name in names) + end) + elif isinstance(result, SFTPLimits): + result.log(self.logger, 'Sending') + response = result.encode() + elif isinstance(result, SFTPRanges): + result.log(self.logger, 'Sending') + response = result.encode() else: attrs: _SupportsEncode @@ -5382,7 +6043,7 @@ async def _process_packet(self, pkttype: int, pktid: int, str(exc.reason)) response = exc.encode(self._version) - except NotImplementedError as exc: + except NotImplementedError: assert handler is not None return_type = FXP_STATUS @@ -5391,7 +6052,7 @@ async def _process_packet(self, pkttype: int, pktid: int, self.logger.debug1('Sending operation not supported: %s', op_name) response = (UInt32(FX_OP_UNSUPPORTED) + - String('Operation not supported: %s' % op_name) + + String(f'Operation not supported: {op_name}') + String(DEFAULT_LANG)) except OSError as exc: return_type = FXP_STATUS @@ -5441,7 +6102,7 @@ async def _process_packet(self, pkttype: int, pktid: int, response = SFTPError(code, reason).encode(self._version) except Exception as exc: # pragma: no cover return_type = FXP_STATUS - reason = 'Uncaught exception: %s' % str(exc) + reason = f'Uncaught exception: {exc}' self.logger.debug1('Sending failure: %s', reason, exc_info=sys.exc_info) @@ -5459,10 +6120,10 @@ async def _process_open(self, packet: SSHPacket) -> bytes: if self._version >= 5: desired_access = packet.get_uint32() flags = packet.get_uint32() - flagmsg = 'access=0x%04x, flags=0x%04x' % (desired_access, flags) + flagmsg = f'access=0x{desired_access:04x}, flags=0x{flags:04x}' else: pflags = packet.get_uint32() - flagmsg = 'pflags=0x%02x' % pflags + flagmsg = f'pflags=0x{pflags:02x}' attrs = SFTPAttrs.decode(packet, self._version) @@ -5476,14 +6137,14 @@ async def _process_open(self, packet: SSHPacket) -> bytes: unsupported_access = desired_access & ~self._supported_access_mask if unsupported_access: - raise SFTPInvalidParameter('Unsupported access flags: 0x%08x' % - unsupported_access) + raise SFTPInvalidParameter( + f'Unsupported access flags: 0x{unsupported_access:08x}') unsupported_flags = flags & ~self._supported_open_flags if unsupported_flags: - raise SFTPInvalidParameter('Unsupported open flags: 0x%08x' % - unsupported_flags) + raise SFTPInvalidParameter( + f'Unsupported open flags: 0x{unsupported_flags:08x}') result = self._server.open56(path, desired_access, flags, attrs) else: @@ -5511,7 +6172,6 @@ async def _process_close(self, packet: SSHPacket) -> None: result = self._server.close(file_obj) if inspect.isawaitable(result): - assert result is not None await result return @@ -5545,16 +6205,8 @@ async def _process_read(self, packet: SSHPacket) -> Tuple[bytes, bool]: result: bytes if self._version >= 6: - attr_result = self._server.fstat(file_obj) - - if inspect.isawaitable(attr_result): - attr_result = await cast(Awaitable[_SFTPOSAttrs], - attr_result) - - if isinstance(attr_result, os.stat_result): - attrs = SFTPAttrs.from_local(attr_result) - else: - attrs = cast(SFTPAttrs, attr_result) + attrs = await self._server.convert_attrs( + self._server.fstat(file_obj)) at_end = offset + len(result) == attrs.size else: @@ -5603,7 +6255,7 @@ async def _process_lstat(self, packet: SSHPacket) -> _SFTPOSAttrs: packet.check_end() self.logger.debug1('Received lstat for %s%s', path, - ', flags=0x%08x' % flags if flags else '') + f', flags=0x{flags:08x}' if flags else '') # Ignore flags for now, returning all available fields @@ -5627,7 +6279,7 @@ async def _process_fstat(self, packet: SSHPacket) -> _SFTPOSAttrs: packet.check_end() self.logger.debug1('Received fstat for handle %s%s', handle.hex(), - ', flags=0x%08x' % flags if flags else '') + f', flags=0x{flags:08x}' if flags else '') file_obj = self._file_handles.get(handle) @@ -5658,7 +6310,6 @@ async def _process_setstat(self, packet: SSHPacket) -> None: result = self._server.setstat(path, attrs) if inspect.isawaitable(result): - assert result is not None await result async def _process_fsetstat(self, packet: SSHPacket) -> None: @@ -5679,7 +6330,6 @@ async def _process_fsetstat(self, packet: SSHPacket) -> None: result = self._server.fsetstat(file_obj, attrs) if inspect.isawaitable(result): - assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') @@ -5694,46 +6344,8 @@ async def _process_opendir(self, packet: SSHPacket) -> bytes: self.logger.debug1('Received opendir for %s', path) - listdir_result = self._server.listdir(path) - - if inspect.isawaitable(listdir_result): - listdir_result = await cast(Awaitable[Sequence[bytes]], - listdir_result) - - listdir_result: Sequence[Union[bytes, SFTPName]] - entries = list(listdir_result) - - for i, entry in enumerate(entries): - if isinstance(entry, bytes): - entries[i] = entry = SFTPName(entry) - - filename = os.path.join(path, cast(bytes, entry.filename)) - attr_result = self._server.lstat(filename) - - if inspect.isawaitable(attr_result): - attr_result = await cast(Awaitable[_SFTPOSAttrs], - attr_result) - - attr_result: _SFTPOSAttrs - - if isinstance(attr_result, os.stat_result): - attr_result = SFTPAttrs.from_local(attr_result) - - attr_result: SFTPAttrs - - entry.attrs = attr_result - - if not entry.longname and self._version == 3: - longname_result = self._server.format_longname(entry) - - if inspect.isawaitable(longname_result): - assert longname_result is not None - await longname_result - - entries: List[SFTPName] - handle = self._get_next_handle() - self._dir_handles[handle] = entries + self._dir_handles[handle] = self._server.scandir(path) return handle async def _process_readdir(self, packet: SSHPacket) -> _SFTPNames: @@ -5747,12 +6359,28 @@ async def _process_readdir(self, packet: SSHPacket) -> _SFTPNames: self.logger.debug1('Received readdir for handle %s', handle.hex()) names = self._dir_handles.get(handle) + if names: - result = names[:_MAX_READDIR_NAMES] - del names[:_MAX_READDIR_NAMES] - return result, not names - elif names is not None: - raise SFTPEOFError + count = 0 + result: List[SFTPName] = [] + + async for name in names: + if not name.longname and self._version == 3: + longname_result = self._server.format_longname(name) + + if inspect.isawaitable(longname_result): + await longname_result + + result.append(name) + count += 1 + + if count == _MAX_READDIR_NAMES: + break + + if result: + return result, count < _MAX_READDIR_NAMES + else: + raise SFTPEOFError else: raise SFTPInvalidHandle('Invalid file handle') @@ -5769,7 +6397,6 @@ async def _process_remove(self, packet: SSHPacket) -> None: result = self._server.remove(path) if inspect.isawaitable(result): - assert result is not None await result async def _process_mkdir(self, packet: SSHPacket) -> None: @@ -5786,7 +6413,6 @@ async def _process_mkdir(self, packet: SSHPacket) -> None: result = self._server.mkdir(path, attrs) if inspect.isawaitable(result): - assert result is not None await result async def _process_rmdir(self, packet: SSHPacket) -> None: @@ -5802,7 +6428,6 @@ async def _process_rmdir(self, packet: SSHPacket) -> None: result = self._server.rmdir(path) if inspect.isawaitable(result): - assert result is not None await result async def _process_realpath(self, packet: SSHPacket) -> _SFTPNames: @@ -5820,14 +6445,14 @@ async def _process_realpath(self, packet: SSHPacket) -> _SFTPNames: compose_paths.append(packet.get_string()) try: - checkmsg = ', check=%s' % self._realpath_check_names[check] + checkmsg = f', check={self._realpath_check_names[check]}' except KeyError: raise SFTPInvalidParameter('Invalid check value') from None else: check = FXRP_NO_CHECK self.logger.debug1('Received realpath for %s%s%s', path, - b', compose_path: %s' % b', '.join(compose_paths) + b', compose_path: ' + b', '.join(compose_paths) if compose_paths else b'', checkmsg) for cpath in compose_paths: @@ -5844,16 +6469,8 @@ async def _process_realpath(self, packet: SSHPacket) -> _SFTPNames: if check != FXRP_NO_CHECK: try: - attr_result = self._server.stat(result) - - if inspect.isawaitable(attr_result): - attr_result = await cast(Awaitable[_SFTPOSAttrs], - attr_result) - - if isinstance(attr_result, os.stat_result): - attrs = SFTPAttrs.from_local(attr_result) - else: - attrs = cast(SFTPAttrs, attr_result) + attrs = await self._server.convert_attrs( + self._server.stat(result)) except (OSError, SFTPError): if check == FXRP_STAT_ALWAYS: raise @@ -5871,7 +6488,7 @@ async def _process_stat(self, packet: SSHPacket) -> _SFTPOSAttrs: packet.check_end() self.logger.debug1('Received stat for %s%s', path, - ', flags=0x%08x' % flags if flags else '') + f', flags=0x{flags:08x}' if flags else '') # Ignore flags for now, returning all available fields result = self._server.stat(path) @@ -5891,7 +6508,7 @@ async def _process_rename(self, packet: SSHPacket) -> None: if self._version >= 5: flags = packet.get_uint32() - flag_text = ', flags=0x%08x' % flags + flag_text = f', flags=0x{flags:08x}' else: flags = 0 flag_text = '' @@ -5908,7 +6525,6 @@ async def _process_rename(self, packet: SSHPacket) -> None: result = self._server.rename(oldpath, newpath) if inspect.isawaitable(result): - assert result is not None await result async def _process_readlink(self, packet: SSHPacket) -> _SFTPNames: @@ -5948,7 +6564,6 @@ async def _process_symlink(self, packet: SSHPacket) -> None: result = self._server.symlink(oldpath, newpath) if inspect.isawaitable(result): - assert result is not None await result async def _process_link(self, packet: SSHPacket) -> None: @@ -5970,7 +6585,6 @@ async def _process_link(self, packet: SSHPacket) -> None: result = self._server.link(oldpath, newpath) if inspect.isawaitable(result): - assert result is not None await result async def _process_lock(self, packet: SSHPacket) -> None: @@ -5992,7 +6606,6 @@ async def _process_lock(self, packet: SSHPacket) -> None: result = self._server.lock(file_obj, offset, length, flags) if inspect.isawaitable(result): # pragma: no branch - assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') @@ -6014,7 +6627,6 @@ async def _process_unlock(self, packet: SSHPacket) -> None: result = self._server.unlock(file_obj, offset, length) if inspect.isawaitable(result): # pragma: no branch - assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') @@ -6032,7 +6644,6 @@ async def _process_posix_rename(self, packet: SSHPacket) -> None: result = self._server.posix_rename(oldpath, newpath) if inspect.isawaitable(result): - assert result is not None await result async def _process_statvfs(self, packet: SSHPacket) -> _SFTPOSVFSAttrs: @@ -6087,7 +6698,6 @@ async def _process_openssh_link(self, packet: SSHPacket) -> None: result = self._server.link(oldpath, newpath) if inspect.isawaitable(result): - assert result is not None await result async def _process_fsync(self, packet: SSHPacket) -> None: @@ -6104,11 +6714,120 @@ async def _process_fsync(self, packet: SSHPacket) -> None: result = self._server.fsync(file_obj) if inspect.isawaitable(result): - assert result is not None await result else: raise SFTPInvalidHandle('Invalid file handle') + async def _process_lsetstat(self, packet: SSHPacket) -> None: + """Process an incoming SFTP lsetstat request""" + + path = packet.get_string() + attrs = SFTPAttrs.decode(packet, self._version) + + if self._version < 6: + packet.check_end() + + self.logger.debug1('Received lsetstat for %s%s', + path, hide_empty(attrs)) + + result = self._server.lsetstat(path, attrs) + + if inspect.isawaitable(result): + await result + + async def _process_limits(self, packet: SSHPacket) -> SFTPLimits: + """Process an incoming SFTP limits request""" + + packet.check_end() + + nfiles = os.sysconf('SC_OPEN_MAX') - 5 if hasattr(os, 'sysconf') else 0 + + return SFTPLimits(MAX_SFTP_PACKET_LEN, MAX_SFTP_READ_LEN, + MAX_SFTP_WRITE_LEN, nfiles) + + async def _process_copy_data(self, packet: SSHPacket) -> None: + """Process an incoming copy data request""" + + read_from_handle = packet.get_string() + read_from_offset = packet.get_uint64() + read_from_length = packet.get_uint64() + write_to_handle = packet.get_string() + write_to_offset = packet.get_uint64() + packet.check_end() + + self.logger.debug1('Received copy-data from handle %s, ' + 'offset %d, length %d to handle %s, ' + 'offset %d', read_from_handle.hex(), + read_from_offset, read_from_length, + write_to_handle.hex(), write_to_offset) + + src = self._file_handles.get(read_from_handle) + dst = self._file_handles.get(write_to_handle) + + if src and dst: + read_to_end = read_from_length == 0 + + while read_to_end or read_from_length: + if read_to_end: + size = _COPY_DATA_BLOCK_SIZE + else: + size = min(read_from_length, _COPY_DATA_BLOCK_SIZE) + + data = self._server.read(src, read_from_offset, size) + + if inspect.isawaitable(data): + data = await data + + data: bytes + + result = self._server.write(dst, write_to_offset, data) + + if inspect.isawaitable(result): + await result + + if len(data) < size: + break + + read_from_offset += size + write_to_offset += size + + if not read_to_end: + read_from_length -= size + else: + raise SFTPInvalidHandle('Invalid file handle') + + async def _process_ranges(self, packet: SSHPacket) -> SFTPRanges: + """Process an incoming sparse file ranges request""" + + handle = packet.get_string() + offset = packet.get_uint64() + length = packet.get_uint64() + packet.check_end() + + self.logger.debug1('Received ranges request for handle %s, ' + 'offset %d, length %d', handle.hex(), + offset, length) + + file_obj = cast(_SFTPFileObj, self._file_handles.get(handle)) + + if file_obj: + count = 0 + result: List[Tuple[int, int]] = [] + + async for data_range in _request_ranges(file_obj, offset, length): + result.append(data_range) + count += 1 + + if count == _MAX_SPARSE_RANGES: + break + + if result: + return SFTPRanges(result, count < _MAX_SPARSE_RANGES) + else: + raise SFTPEOFError + else: + raise SFTPInvalidHandle('Invalid file handle') + _packet_handlers: Dict[Union[int, bytes], _SFTPPacketHandler] = { FXP_OPEN: _process_open, FXP_CLOSE: _process_close, @@ -6135,7 +6854,11 @@ async def _process_fsync(self, packet: SSHPacket) -> None: b'statvfs@openssh.com': _process_statvfs, b'fstatvfs@openssh.com': _process_fstatvfs, b'hardlink@openssh.com': _process_openssh_link, - b'fsync@openssh.com': _process_fsync + b'fsync@openssh.com': _process_fsync, + b'lsetstat@openssh.com': _process_lsetstat, + b'limits@openssh.com': _process_limits, + b'copy-data': _process_copy_data, + b'ranges@asyncssh.com': _process_ranges } async def run(self) -> None: @@ -6188,7 +6911,7 @@ async def run(self) -> None: UInt32(self._supported_attrib_mask) + \ UInt32(self._supported_open_flags) + \ UInt32(self._supported_access_mask) + \ - UInt32(_MAX_SFTP_READ_SIZE) + ext_names + \ + UInt32(MAX_SFTP_READ_LEN) + ext_names + \ attrib_ext_names extensions.append((b'supported', supported)) @@ -6199,7 +6922,7 @@ async def run(self) -> None: UInt32(self._supported_attrib_mask) + \ UInt32(self._supported_open_flags) + \ UInt32(self._supported_access_mask) + \ - UInt32(_MAX_SFTP_READ_SIZE) + \ + UInt32(MAX_SFTP_READ_LEN) + \ UInt16(self._supported_open_block_vector) + \ UInt16(self._supported_block_vector) + \ UInt32(len(self._attrib_extensions)) + \ @@ -6251,7 +6974,7 @@ class SFTPServer: .. note:: Any method can optionally be defined as a coroutine if that method needs to perform - blocking opertions to determine its result. + blocking operations to determine its result. The `chan` object provided here is the :class:`SSHServerChannel` instance this SFTP server is associated with. It can be queried to @@ -6322,6 +7045,29 @@ def logger(self) -> SSHLogger: return self._chan.logger + async def convert_attrs(self, result: MaybeAwait[_SFTPOSAttrs]) -> \ + SFTPAttrs: + """Convert stat result to SFTPAttrs""" + + if inspect.isawaitable(result): + result = await cast(Awaitable[_SFTPOSAttrs], result) + + result: _SFTPOSAttrs + + if isinstance(result, os.stat_result): + result = SFTPAttrs.from_local(result) + + result: SFTPAttrs + + return result + + async def _to_sftpname(self, parent: bytes, name: bytes) -> SFTPName: + """Construct an SFTPName for a filename in a directory""" + + path = posixpath.join(parent, name) + attrs = await self.convert_attrs(self.lstat(path)) + return SFTPName(name, attrs=attrs) + def format_user(self, uid: Optional[int]) -> str: """Return the user name associated with a uid @@ -6408,8 +7154,8 @@ def format_longname(self, name: SFTPName) -> MaybeAwait[None]: else: modtime = '' - detail = '{:10s} {:>4s} {:8s} {:8s} {:>8s} {:12s} '.format( - mode, nlink, user, group, size, modtime) + detail = f'{mode:10s} {nlink:>4s} {user:8s} {group:8s} ' \ + f'{size:>8s} {modtime:12s} ' name.longname = detail.encode('utf-8') + cast(bytes, name.filename) @@ -6546,8 +7292,14 @@ def open(self, path: bytes, pflags: int, attrs: SFTPAttrs) -> \ pass perms = 0o666 if attrs.permissions is None else attrs.permissions - return open(_to_local_path(self.map_path(path)), mode, buffering=0, - opener=lambda path, _: os.open(path, flags, perms)) + + file_obj = open(_to_local_path(self.map_path(path)), mode, buffering=0, + opener=lambda path, _: os.open(path, flags, perms)) + + if mode[0] in 'wx': + make_sparse_file(file_obj) + + return file_obj def open56(self, path: bytes, desired_access: int, flags: int, attrs: SFTPAttrs) -> MaybeAwait[object]: @@ -6643,8 +7395,14 @@ def open56(self, path: bytes, desired_access: int, flags: int, pass perms = 0o666 if attrs.permissions is None else attrs.permissions - return open(_to_local_path(self.map_path(path)), mode, buffering=0, - opener=lambda path, _: os.open(path, open_flags, perms)) + + file_obj = open(_to_local_path(self.map_path(path)), mode, buffering=0, + opener=lambda path, _: os.open(path, open_flags, perms)) + + if mode[0] in 'wx': + make_sparse_file(file_obj) + + return file_obj def close(self, file_obj: object) -> MaybeAwait[None]: """Close an open file or directory @@ -6769,6 +7527,30 @@ def setstat(self, path: bytes, attrs: SFTPAttrs) -> MaybeAwait[None]: """ _setstat(_to_local_path(self.map_path(path)), attrs) + + return None + + def lsetstat(self, path: bytes, attrs: SFTPAttrs) -> MaybeAwait[None]: + """Set attributes of a file, directory, or symlink + + This method sets attributes of a file, directory, or symlink. + A subset of the fields in `attrs` can be initialized and only + those attributes should be changed. + + :param path: + The path of the remote file or directory to set attributes for + :param attrs: + File attributes to set + :type path: `bytes` + :type attrs: :class:`SFTPAttrs` + + :raises: :exc:`SFTPError` to return an error to the client + + """ + + _setstat(_to_local_path(self.map_path(path)), attrs, + follow_symlinks=False) + return None def fsetstat(self, file_obj: object, attrs: SFTPAttrs) -> MaybeAwait[None]: @@ -6795,30 +7577,59 @@ def fsetstat(self, file_obj: object, attrs: SFTPAttrs) -> MaybeAwait[None]: return None - def listdir(self, path: bytes) -> \ - MaybeAwait[Sequence[Union[bytes, SFTPName]]]: - """List the contents of a directory + async def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: + """Return names and attributes of the files in a directory + + This function returns an async iterator of :class:`SFTPName` + entries corresponding to files in the requested directory. :param path: - The path of the directory to open + The path of the directory to scan :type path: `bytes` - :returns: A list of names of files in the directory or - :class:`SFTPName` objects containing file names - and attributes + :returns: An async iterator of :class:`SFTPName` :raises: :exc:`SFTPError` to return an error to the client """ - files = os.listdir(_to_local_path(self.map_path(path))) + if hasattr(self, 'listdir'): + # Support backward compatibility with older AsyncSSH versions + # which allowed listdir() to be overridden, returning a list + # of either :class:`SFTPName` objects or plain filenames, in + # which case :meth:`lstat` is called to retrieve attribute + # information. - if sys.platform == 'win32': # pragma: no cover - files = [os.fsencode(f) for f in files] + # pylint: disable=no-member + listdir_result = self.listdir(path) # type: ignore + + if inspect.isawaitable(listdir_result): + listdir_result = await cast( + Awaitable[Sequence[Union[bytes, SFTPName]]], + listdir_result) + + listdir_result: Sequence[Union[bytes, SFTPName]] + + for name in listdir_result: + if isinstance(name, bytes): + yield await self._to_sftpname(path, name) + else: + yield name + else: + for name in (b'.', b'..'): + yield await self._to_sftpname(path, name) + + with os.scandir(_to_local_path(self.map_path(path))) as entries: + for entry in entries: + filename = entry.name - files: List[bytes] + if sys.platform == 'win32': # pragma: no cover + filename = os.fsencode(filename) - return [b'.', b'..'] + files + attrs = SFTPAttrs.from_local( + entry.stat(follow_symlinks=False)) + + yield SFTPName(filename, attrs=attrs) def remove(self, path: bytes) -> MaybeAwait[None]: """Remove a file or symbolic link @@ -6948,6 +7759,14 @@ def readlink(self, path: bytes) -> MaybeAwait[bytes]: """ path = os.readlink(_to_local_path(self.map_path(path))) + + if sys.platform == 'win32' and \ + path.startswith('\\\\?\\'): # pragma: no cover + path = path[4:] + + if self._chroot: + path = os.path.realpath(path) + return self.reverse_map_path(_from_local_path(path)) def symlink(self, oldpath: bytes, newpath: bytes) -> MaybeAwait[None]: @@ -7103,7 +7922,7 @@ class LocalFile: def __init__(self, file: _SFTPFileObj): self._file = file - async def __aenter__(self) -> 'LocalFile': # pragma: no cover + async def __aenter__(self) -> Self: # pragma: no cover """Allow LocalFile to be used as an async context manager""" return self @@ -7117,6 +7936,12 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]], await self.close() return False + def request_ranges(self, offset: int, length: int) -> \ + AsyncIterator[Tuple[int, int]]: + """Return data ranges containing data in a local file""" + + return _request_ranges(self._file, offset, length) + async def read(self, size: int, offset: int) -> bytes: """Read data from the local file""" @@ -7138,6 +7963,8 @@ async def close(self) -> None: class LocalFS: """An async wrapper around local filesystem access""" + limits = SFTPLimits(0, MAX_SFTP_READ_LEN, MAX_SFTP_WRITE_LEN, 0) + @staticmethod def basename(path: bytes) -> bytes: """Return the final component of a local file path""" @@ -7167,20 +7994,18 @@ def compose_path(self, path: bytes, return posixpath.join(parent, path) if parent else path - async def stat(self, path: bytes) -> 'SFTPAttrs': - """Get attributes of a local file or directory, following symlinks""" - - return SFTPAttrs.from_local(os.stat(_to_local_path(path))) - - async def lstat(self, path: bytes) -> 'SFTPAttrs': + async def stat(self, path: bytes, *, + follow_symlinks: bool = True) -> 'SFTPAttrs': """Get attributes of a local file, directory, or symlink""" - return SFTPAttrs.from_local(os.lstat(_to_local_path(path))) + return SFTPAttrs.from_local(os.stat(_to_local_path(path), + follow_symlinks=follow_symlinks)) - async def setstat(self, path: bytes, attrs: 'SFTPAttrs') -> None: - """Set attributes of a local file or directory""" + async def setstat(self, path: bytes, attrs: 'SFTPAttrs', *, + follow_symlinks: bool = True) -> None: + """Set attributes of a local file, directory, or symlink""" - _setstat(_to_local_path(path), attrs) + _setstat(_to_local_path(path), attrs, follow_symlinks=follow_symlinks) async def exists(self, path: bytes) -> bool: """Return if the local path exists and isn't a broken symbolic link""" @@ -7192,15 +8017,18 @@ async def isdir(self, path: bytes) -> bool: return os.path.isdir(_to_local_path(path)) - async def listdir(self, path: bytes) -> Sequence[bytes]: - """Read the names of the files in a local directory""" + async def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: + """Return names and attributes of the files in a local directory""" - files = os.listdir(_to_local_path(path)) + with os.scandir(_to_local_path(path)) as entries: + for entry in entries: + filename = entry.name - if sys.platform == 'win32': # pragma: no cover - files = [os.fsencode(f) for f in files] + if sys.platform == 'win32': # pragma: no cover + filename = os.fsencode(filename) - return files + attrs = SFTPAttrs.from_local(entry.stat(follow_symlinks=False)) + yield SFTPName(filename, attrs=attrs) async def mkdir(self, path: bytes) -> None: """Create a local directory with the specified attributes""" @@ -7210,7 +8038,13 @@ async def mkdir(self, path: bytes) -> None: async def readlink(self, path: bytes) -> bytes: """Return the target of a local symbolic link""" - return _from_local_path(os.readlink(_to_local_path(path))) + path = os.readlink(_to_local_path(path)) + + if sys.platform == 'win32' and \ + path.startswith('\\\\?\\'): # pragma: no cover + path = path[4:] + + return _from_local_path(path) async def symlink(self, oldpath: bytes, newpath: bytes) -> None: """Create a local symbolic link""" @@ -7218,12 +8052,19 @@ async def symlink(self, oldpath: bytes, newpath: bytes) -> None: os.symlink(_to_local_path(oldpath), _to_local_path(newpath)) @async_context_manager - async def open(self, path: bytes, mode: str) -> LocalFile: + async def open(self, path: bytes, mode: str, + block_size: int = -1) -> LocalFile: """Open a local file""" # pylint: disable=unused-argument - return LocalFile(open(_to_local_path(path), mode)) + file_obj = open(_to_local_path(path), mode) + + if mode[0] in 'wx': + make_sparse_file(file_obj) + + return LocalFile(file_obj) + local_fs = LocalFS() @@ -7235,7 +8076,7 @@ def __init__(self, server: SFTPServer, file_obj: object): self._server = server self._file_obj = file_obj - async def __aenter__(self) -> 'SFTPServerFile': # pragma: no cover + async def __aenter__(self) -> Self: # pragma: no cover """Allow SFTPServerFile to be used as an async context manager""" return self @@ -7279,7 +8120,6 @@ async def close(self) -> None: result = self._server.close(self._file_obj) if inspect.isawaitable(result): - assert result is not None await result @@ -7316,7 +8156,6 @@ async def setstat(self, path: bytes, attrs: SFTPAttrs) -> None: result = self._server.setstat(path, attrs) if inspect.isawaitable(result): - assert result is not None await result async def _type(self, path: bytes) -> int: @@ -7335,24 +8174,17 @@ async def _type(self, path: bytes) -> int: async def exists(self, path: bytes) -> bool: """Return if a path exists""" - return (await self._type(path)) != FILEXFER_TYPE_UNKNOWN + return await self._type(path) != FILEXFER_TYPE_UNKNOWN async def isdir(self, path: bytes) -> bool: """Return if the path refers to a directory""" - return (await self._type(path)) == FILEXFER_TYPE_DIRECTORY - - async def listdir(self, path: bytes) -> Sequence[bytes]: - """List the contents of a directory""" + return await self._type(path) == FILEXFER_TYPE_DIRECTORY - files = self._server.listdir(path) + def scandir(self, path: bytes) -> AsyncIterator[SFTPName]: + """Return names and attributes of the files in a directory""" - if inspect.isawaitable(files): - files = await cast(Awaitable[Sequence[bytes]], files) - - files: Sequence[bytes] - - return files + return self._server.scandir(path) async def mkdir(self, path: bytes) -> None: """Create a directory""" @@ -7360,7 +8192,6 @@ async def mkdir(self, path: bytes) -> None: result = self._server.mkdir(path, SFTPAttrs()) if inspect.isawaitable(result): - assert result is not None await result @async_context_manager @@ -7392,16 +8223,32 @@ async def start_sftp_client(conn: 'SSHClientConnection', conn.create_task(handler.recv_packets(), handler.logger) + await handler.request_limits() + return SFTPClient(handler, path_encoding, path_errors) -def run_sftp_server(sftp_server: SFTPServer, reader: 'SSHReader[bytes]', - writer: 'SSHWriter[bytes]', - sftp_version: int) -> Awaitable[None]: - """Return a handler for an SFTP server session""" +async def _sftp_handler(sftp_server: MaybeAwait[SFTPServer], + reader: 'SSHReader[bytes]', + writer: 'SSHWriter[bytes]', + sftp_version: int) -> None: + """Run an SFTP server to handle this request""" + + if inspect.isawaitable(sftp_server): + sftp_server = await sftp_server + + sftp_server: SFTPServer handler = SFTPServerHandler(sftp_server, reader, writer, sftp_version) - handler.logger.info('Starting SFTP server') + await handler.run() + + +def run_sftp_server(sftp_server: MaybeAwait[SFTPServer], + reader: 'SSHReader[bytes]', writer: 'SSHWriter[bytes]', + sftp_version: int) -> Awaitable[None]: + """Return a handler for an SFTP server session""" + + reader.logger.info('Starting SFTP server') - return handler.run() + return _sftp_handler(sftp_server, reader, writer, sftp_version) diff --git a/asyncssh/sk.py b/asyncssh/sk.py index 969bff9..ca5aef7 100644 --- a/asyncssh/sk.py +++ b/asyncssh/sk.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022 by Ron Frederick and others. +# Copyright (c) 2019-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,6 +20,8 @@ """U2F security key handler""" +from base64 import urlsafe_b64encode +import ctypes from hashlib import sha256 import hmac import time @@ -54,6 +56,12 @@ def _decode_public_key(alg: int, public_key: Mapping[int, object]) -> bytes: return b'\x04' + result + cast(bytes, public_key[-3]) +def _verify_rp_id(_rp_id: str, _origin: str): + """Allow any relying party name -- SSH encodes the application here""" + + return True + + def _ctap1_poll(poll_interval: float, func: Callable[..., _PollResult], *args: object) -> _PollResult: """Poll until a CTAP1 response is received""" @@ -69,7 +77,7 @@ def _ctap1_poll(poll_interval: float, func: Callable[..., _PollResult], def _ctap1_enroll(dev: 'CtapHidDevice', alg: int, - application: bytes) -> Tuple[bytes, bytes]: + application: str) -> Tuple[bytes, bytes]: """Enroll a new security key using CTAP version 1""" ctap1 = Ctap1(dev) @@ -77,21 +85,20 @@ def _ctap1_enroll(dev: 'CtapHidDevice', alg: int, if alg != SSH_SK_ECDSA: raise ValueError('Unsupported algorithm') - app_hash = sha256(application).digest() + app_hash = sha256(application.encode('utf-8')).digest() registration = _ctap1_poll(_CTAP1_POLL_INTERVAL, ctap1.register, _dummy_hash, app_hash) return registration.public_key, registration.key_handle -def _ctap2_enroll(dev: 'CtapHidDevice', alg: int, application: bytes, +def _ctap2_enroll(dev: 'CtapHidDevice', alg: int, application: str, user: str, pin: Optional[str], resident: bool) -> Tuple[bytes, bytes]: """Enroll a new security key using CTAP version 2""" ctap2 = Ctap2(dev) - application = application.decode('utf-8') rp = {'id': application, 'name': application} user_cred = {'id': user.encode('utf-8'), 'name': user} key_params = [{'type': 'public-key', 'alg': alg}] @@ -118,13 +125,31 @@ def _ctap2_enroll(dev: 'CtapHidDevice', alg: int, application: bytes, return _decode_public_key(alg, cdata.public_key), cdata.credential_id -def _ctap1_sign(dev: 'CtapHidDevice', message_hash: bytes, application: bytes, +def _win_enroll(alg: int, application: str, user: str) -> Tuple[bytes, bytes]: + """Enroll a new security key using Windows WebAuthn API""" + + client = WindowsClient(application, verify=_verify_rp_id) + + rp = {'id': application, 'name': application} + user_cred = {'id': user.encode('utf-8'), 'name': user} + key_params = [{'type': 'public-key', 'alg': alg}] + options = {'rp': rp, 'user': user_cred, 'challenge': b'', + 'pubKeyCredParams': key_params} + + result = client.make_credential(options) + cdata = result.attestation_object.auth_data.credential_data + + # pylint: disable=no-member + return _decode_public_key(alg, cdata.public_key), cdata.credential_id + + +def _ctap1_sign(dev: 'CtapHidDevice', message_hash: bytes, application: str, key_handle: bytes) -> Tuple[int, int, bytes]: """Sign a message with a security key using CTAP version 1""" ctap1 = Ctap1(dev) - app_hash = sha256(application).digest() + app_hash = sha256(application.encode('utf-8')).digest() auth_response = _ctap1_poll(_CTAP1_POLL_INTERVAL, ctap1.authenticate, message_hash, app_hash, key_handle) @@ -137,16 +162,20 @@ def _ctap1_sign(dev: 'CtapHidDevice', message_hash: bytes, application: bytes, def _ctap2_sign(dev: 'CtapHidDevice', message_hash: bytes, - application: bytes, key_handle: bytes, + application: str, key_handle: bytes, touch_required: bool) -> Tuple[int, int, bytes]: """Sign a message with a security key using CTAP version 2""" ctap2 = Ctap2(dev) - application = application.decode('utf-8') allow_creds = [{'type': 'public-key', 'id': key_handle}] options = {'up': touch_required} + # See if key handle exists before requiring touch + if touch_required: + ctap2.get_assertions(application, message_hash, allow_creds, + options={'up': False}) + assertion = ctap2.get_assertions(application, message_hash, allow_creds, options=options)[0] @@ -155,10 +184,38 @@ def _ctap2_sign(dev: 'CtapHidDevice', message_hash: bytes, return auth_data.flags, auth_data.counter, assertion.signature -def sk_enroll(alg: int, application: bytes, user: str, - pin: Optional[str], resident: bool) -> Tuple[bytes, bytes]: +def _win_sign(data: bytes, application: str, + key_handle: bytes) -> Tuple[int, int, bytes, bytes]: + """Sign a message with a security key using Windows WebAuthn API""" + + client = WindowsClient(application, verify=_verify_rp_id) + + creds = [{'type': 'public-key', 'id': key_handle}] + options = {'challenge': data, 'rpId': application, + 'allowCredentials': creds} + + result = client.get_assertion(options).get_response(0) + auth_data = result.authenticator_data + + return auth_data.flags, auth_data.counter, \ + result.signature, bytes(result.client_data) + + +def sk_webauthn_prefix(data: bytes, application: str) -> bytes: + """Calculate a WebAuthn request prefix""" + + return b'{"type":"webauthn.get","challenge":"' + \ + urlsafe_b64encode(data).rstrip(b'=') + b'","origin":"' + \ + application.encode('utf-8') + b'"' + + +def sk_enroll(alg: int, application: str, user: str, pin: Optional[str], + resident: bool) -> Tuple[bytes, bytes]: """Enroll a new security key""" + if sk_use_webauthn: + return _win_enroll(alg, application, user) + try: dev = next(CtapHidDevice.list_devices()) except StopIteration: @@ -173,7 +230,7 @@ def sk_enroll(alg: int, application: bytes, user: str, raise ValueError('Invalid PIN') from None else: raise ValueError(str(exc)) from None - except ValueError as exc: + except ValueError: try: return _ctap1_enroll(dev, alg, application) except ApduError as exc: @@ -182,22 +239,35 @@ def sk_enroll(alg: int, application: bytes, user: str, dev.close() -def sk_sign(message_hash: bytes, application: bytes, key_handle: bytes, - flags: int) -> Tuple[int, int, bytes]: +def sk_sign(data: bytes, application: str, key_handle: bytes, flags: int, + is_webauthn: bool = False) -> Tuple[int, int, bytes, bytes]: """Sign a message with a security key""" touch_required = bool(flags & SSH_SK_USER_PRESENCE_REQD) + if is_webauthn and sk_use_webauthn: + return _win_sign(data, application, key_handle) + + if is_webauthn: + data = sk_webauthn_prefix(data, application) + b'}' + + message_hash = sha256(data).digest() + for dev in CtapHidDevice.list_devices(): try: - return _ctap2_sign(dev, message_hash, application, - key_handle, touch_required) + flags, counter, sig = _ctap2_sign(dev, message_hash, application, + key_handle, touch_required) + + return flags, counter, sig, data except CtapError as exc: if exc.code != CtapError.ERR.NO_CREDENTIALS: raise ValueError(str(exc)) from None except ValueError: try: - return _ctap1_sign(dev, message_hash, application, key_handle) + flags, counter, sig = _ctap1_sign(dev, message_hash, + application, key_handle) + + return flags, counter, sig, data except ApduError as exc: if exc.code != APDU.WRONG_DATA: raise ValueError(str(exc)) from None @@ -207,11 +277,11 @@ def sk_sign(message_hash: bytes, application: bytes, key_handle: bytes, raise ValueError('Security key credential not found') -def sk_get_resident(application: bytes, user: Optional[str], +def sk_get_resident(application: str, user: Optional[str], pin: str) -> Sequence[_SKResidentKey]: """Get keys resident on a security key""" - app_hash = sha256(application).digest() + app_hash = sha256(application.encode('utf-8')).digest() result: List[_SKResidentKey] = [] for dev in CtapHidDevice.list_devices(): @@ -257,15 +327,21 @@ def sk_get_resident(application: bytes, user: Optional[str], try: - from fido2.hid import CtapHidDevice + from fido2.client import WindowsClient from fido2.ctap import CtapError from fido2.ctap1 import Ctap1, APDU, ApduError from fido2.ctap2 import Ctap2, ClientPin, PinProtocolV1 from fido2.ctap2 import CredentialManagement + from fido2.hid import CtapHidDevice sk_available = True + + sk_use_webauthn = WindowsClient.is_available() and \ + hasattr(ctypes, 'windll') and \ + not ctypes.windll.shell32.IsUserAnAdmin() except (ImportError, OSError, AttributeError): # pragma: no cover sk_available = False + sk_use_webauthn = False def _sk_not_available(*args: object, **kwargs: object) -> NoReturn: """Report that security key support is unavailable""" diff --git a/asyncssh/sk_ecdsa.py b/asyncssh/sk_ecdsa.py index b8c3854..95790ba 100644 --- a/asyncssh/sk_ecdsa.py +++ b/asyncssh/sk_ecdsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021 by Ron Frederick and others. +# Copyright (c) 2019-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -29,30 +29,34 @@ from .public_key import KeyExportError, SSHKey, SSHOpenSSHCertificateV01 from .public_key import register_public_key_alg, register_certificate_alg from .public_key import register_sk_alg -from .sk import SSH_SK_ECDSA, SSH_SK_USER_PRESENCE_REQD, sk_enroll, sk_sign +from .sk import SSH_SK_ECDSA, SSH_SK_USER_PRESENCE_REQD +from .sk import sk_enroll, sk_sign, sk_webauthn_prefix, sk_use_webauthn -_PrivateKeyArgs = Tuple[bytes, bytes, bytes, int, bytes, bytes] -_PublicKeyArgs = Tuple[bytes, bytes, bytes] +_PrivateKeyArgs = Tuple[bytes, bytes, str, int, bytes, bytes] +_PublicKeyArgs = Tuple[bytes, bytes, str] class _SKECDSAKey(SSHKey): - """Handler for elliptic curve public key encryption""" + """Handler for U2F ECDSA public key encryption""" _key: ECDSAPublicKey use_executor = True - def __init__(self, curve_id: bytes, public_value: bytes, - application: bytes, flags: int = 0, - key_handle: bytes = None, reserved: bytes = b''): + def __init__(self, curve_id: bytes, public_value: bytes, application: str, + flags: int = 0, key_handle: Optional[bytes] = None, + reserved: bytes = b''): super().__init__(ECDSAPublicKey.construct(curve_id, public_value)) self.algorithm = b'sk-ecdsa-sha2-' + curve_id + b'@openssh.com' - self.sig_algorithms = (self.algorithm,) + self.sig_algorithms = (self.algorithm, b'webauthn-' + self.algorithm) self.all_sig_algorithms = set(self.sig_algorithms) + self.use_webauthn = sk_use_webauthn + self._application = application + self._app_hash = sha256(application.encode('utf-8')).digest() self._flags = flags self._key_handle = key_handle self._reserved = reserved @@ -83,7 +87,6 @@ def generate(cls, algorithm: bytes, *, # type: ignore # pylint: disable=arguments-differ - application = application.encode('utf-8') flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 public_value, key_handle = sk_enroll(SSH_SK_ECDSA, application, @@ -117,7 +120,7 @@ def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: curve_id = packet.get_string() public_value = packet.get_string() - application = packet.get_string() + application = packet.get_string().decode('utf-8') flags = packet.get_byte() key_handle = packet.get_string() reserved = packet.get_string() @@ -130,7 +133,7 @@ def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: curve_id = packet.get_string() public_value = packet.get_string() - application = packet.get_string() + application = packet.get_string().decode('utf-8') return curve_id, public_value, application @@ -164,27 +167,46 @@ def encode_agent_cert_private(self) -> bytes: def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: """Compute an SSH-encoded signature of the specified data""" - # pylint: disable=unused-argument - if self._key_handle is None: raise ValueError('Key handle needed for signing') - flags, counter, sig = sk_sign(sha256(data).digest(), self._application, - self._key_handle, self._flags) + is_webauthn = sig_algorithm.startswith(b'webauthn') + + flags, counter, sig, client_data = sk_sign(data, self._application, + self._key_handle, + self._flags, is_webauthn) r, s = cast(Tuple[int, int], der_decode(sig)) - return String(MPInt(r) + MPInt(s)) + Byte(flags) + UInt32(counter) + sig = String(MPInt(r) + MPInt(s)) + Byte(flags) + UInt32(counter) + + if is_webauthn: + sig += String(self._application) + String(client_data) + String('') + + return sig def verify_ssh(self, data: bytes, sig_algorithm: bytes, packet: SSHPacket) -> bool: """Verify an SSH-encoded signature of the specified data""" - # pylint: disable=unused-argument + is_webauthn = sig_algorithm.startswith(b'webauthn') sig = packet.get_string() flags = packet.get_byte() counter = packet.get_uint32() + + if is_webauthn: + _ = packet.get_string() # origin + client_data = packet.get_string() + _ = packet.get_string() # extensions + + prefix = sk_webauthn_prefix(data, self._application) + + if not client_data.startswith(prefix): + return False + + data = client_data + packet.check_end() if self._touch_required and not flags & SSH_SK_USER_PRESENCE_REQD: @@ -197,9 +219,9 @@ def verify_ssh(self, data: bytes, sig_algorithm: bytes, sig = der_encode((r, s)) - return self._key.verify(sha256(self._application).digest() + - Byte(flags) + UInt32(counter) + - sha256(data).digest(), sig, 'sha256') + return self._key.verify(self._app_hash + Byte(flags) + + UInt32(counter) + sha256(data).digest(), + sig, 'sha256') _algorithm = b'sk-ecdsa-sha2-nistp256@openssh.com' @@ -207,7 +229,8 @@ def verify_ssh(self, data: bytes, sig_algorithm: bytes, register_sk_alg(SSH_SK_ECDSA, _SKECDSAKey, b'nistp256') -register_public_key_alg(_algorithm, _SKECDSAKey, True, (_algorithm,)) +register_public_key_alg(_algorithm, _SKECDSAKey, True, + (_algorithm, b'webauthn-' + _algorithm)) register_certificate_alg(1, _algorithm, _cert_algorithm, _SKECDSAKey, SSHOpenSSHCertificateV01, True) diff --git a/asyncssh/sk_eddsa.py b/asyncssh/sk_eddsa.py index 94ea4d7..c25aad6 100644 --- a/asyncssh/sk_eddsa.py +++ b/asyncssh/sk_eddsa.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021 by Ron Frederick and others. +# Copyright (c) 2019-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,8 +31,8 @@ from .sk import SSH_SK_ED25519, SSH_SK_USER_PRESENCE_REQD, sk_enroll, sk_sign -_PrivateKeyArgs = Tuple[bytes, bytes, int, bytes, bytes] -_PublicKeyArgs = Tuple[bytes, bytes] +_PrivateKeyArgs = Tuple[bytes, str, int, bytes, bytes] +_PublicKeyArgs = Tuple[bytes, str] class _SKEd25519Key(SSHKey): @@ -45,12 +45,13 @@ class _SKEd25519Key(SSHKey): all_sig_algorithms = set(sig_algorithms) use_executor = True - def __init__(self, public_value: bytes, application: bytes, - flags: int = 0, key_handle: bytes = None, + def __init__(self, public_value: bytes, application: str, + flags: int = 0, key_handle: Optional[bytes] = None, reserved: bytes = b''): super().__init__(EdDSAPublicKey.construct(b'ed25519', public_value)) self._application = application + self._app_hash = sha256(application.encode('utf-8')).digest() self._flags = flags self._key_handle = key_handle self._reserved = reserved @@ -79,7 +80,6 @@ def generate(cls, algorithm: bytes, *, # type: ignore # pylint: disable=arguments-differ - application = application.encode('utf-8') flags = SSH_SK_USER_PRESENCE_REQD if touch_required else 0 public_value, key_handle = sk_enroll(SSH_SK_ED25519, application, @@ -109,7 +109,7 @@ def decode_ssh_private(cls, packet: SSHPacket) -> _PrivateKeyArgs: """Decode an SSH format U2F Ed25519 private key""" public_value = packet.get_string() - application = packet.get_string() + application = packet.get_string().decode('utf-8') flags = packet.get_byte() key_handle = packet.get_string() reserved = packet.get_string() @@ -121,7 +121,7 @@ def decode_ssh_public(cls, packet: SSHPacket) -> _PublicKeyArgs: """Decode an SSH format U2F Ed25519 public key""" public_value = packet.get_string() - application = packet.get_string() + application = packet.get_string().decode('utf-8') return public_value, application @@ -154,8 +154,8 @@ def sign_ssh(self, data: bytes, sig_algorithm: bytes) -> bytes: if self._key_handle is None: raise ValueError('Key handle needed for signing') - flags, counter, sig = sk_sign(sha256(data).digest(), self._application, - self._key_handle, self._flags) + flags, counter, sig, _ = sk_sign(data, self._application, + self._key_handle, self._flags) return String(sig) + Byte(flags) + UInt32(counter) @@ -173,9 +173,8 @@ def verify_ssh(self, data: bytes, sig_algorithm: bytes, if self._touch_required and not flags & SSH_SK_USER_PRESENCE_REQD: return False - return self._key.verify(sha256(self._application).digest() + - Byte(flags) + UInt32(counter) + - sha256(data).digest(), sig) + return self._key.verify(self._app_hash + Byte(flags) + + UInt32(counter) + sha256(data).digest(), sig) if ed25519_available: # pragma: no branch diff --git a/asyncssh/socks.py b/asyncssh/socks.py index cb818d8..960cda6 100644 --- a/asyncssh/socks.py +++ b/asyncssh/socks.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2021 by Ron Frederick and others. +# Copyright (c) 2018-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional from .forward import SSHForwarderCoro, SSHLocalForwarder +from .session import DataType if TYPE_CHECKING: @@ -212,13 +213,13 @@ def _recv_socks5_port(self, data: bytes) -> None: self._send_socks5_ok() self._connect() - def data_received(self, data: bytes, datatype: int = None) -> None: + def data_received(self, data: bytes, datatype: DataType = None) -> None: """Handle incoming data from the SOCKS client""" if self._recv_handler: self._inpbuf += data - while self._recv_handler: + while self._recv_handler: # type: ignore[truthy-function] if self._bytes_needed < 0: idx = self._inpbuf.find(b'\0') if idx >= 0: diff --git a/asyncssh/stream.py b/asyncssh/stream.py index a8e95a4..247a75d 100644 --- a/asyncssh/stream.py +++ b/asyncssh/stream.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,15 +24,15 @@ import inspect import re from typing import TYPE_CHECKING, Any, AnyStr, AsyncIterator -from typing import Callable, Dict, Generic, Iterable -from typing import List, Optional, Set, Tuple, Union, cast +from typing import Callable, Dict, Generic, Iterable, List +from typing import Optional, Pattern, Set, Tuple, Union, cast from .constants import EXTENDED_DATA_STDERR from .logging import SSHLogger from .misc import MaybeAwait, BreakReceived, SignalReceived from .misc import SoftEOFReceived, TerminalSizeChanged from .session import DataType, SSHClientSession, SSHServerSession -from .session import SSHTCPSession, SSHUNIXSession +from .session import SSHTCPSession, SSHUNIXSession, SSHTunTapSession from .sftp import SFTPServer, run_sftp_server from .scp import run_scp_server @@ -62,7 +62,7 @@ 'SSHWriter'], MaybeAwait[None]] _OptServerSessionFactory = Optional[SSHServerSessionFactory] -SFTPServerFactory = Callable[['SSHChannel[bytes]'], SFTPServer] +SFTPServerFactory = Callable[['SSHChannel[bytes]'], MaybeAwait[SFTPServer]] _OptSFTPServerFactory = Optional[SFTPServerFactory] @@ -81,8 +81,8 @@ def __init__(self, session: 'SSHStreamSession[AnyStr]', async def __aiter__(self) -> AsyncIterator[AnyStr]: """Allow SSHReader to be an async iterator""" - while not self.at_eof(): - yield await self.readline() + async for result in self._session.aiter(self._datatype): + yield result @property def channel(self) -> 'SSHChannel[AnyStr]': @@ -154,7 +154,7 @@ async def read(self, n: int = -1) -> AnyStr: """ - return await self._session.read(n, self._datatype, exact=False) + return await self._session.read(self._datatype, n, exact=False) async def readline(self) -> AnyStr: """Read one line from the stream @@ -175,22 +175,37 @@ async def readline(self) -> AnyStr: """ - try: - return await self.readuntil(_NEWLINE) - except asyncio.IncompleteReadError as exc: - return cast(AnyStr, exc.partial) + return await self._session.readline(self._datatype) - async def readuntil(self, separator: object) -> AnyStr: + async def readuntil(self, separator: object, + max_separator_len = 0) -> AnyStr: """Read data from the stream until `separator` is seen This method is a coroutine which reads from the stream until the requested separator is seen. If a match is found, the returned data will include the separator at the end. - The separator argument can be either a single `bytes` or - `str` value or a sequence of multiple values to match - against, returning data as soon as any of the separators - are found in the stream. + The `separator` argument can be a single `bytes` or `str` + value, a sequence of multiple `bytes` or `str` values, + or a compiled regex (`re.Pattern`) to match against, + returning data as soon as a matching separator is found + in the stream. + + When passing a regex pattern as the separator, the + `max_separator_len` argument should be set to the + maximum length of an expected separator match. This + can greatly improve performance, by minimizing how far + back into the stream must be searched for a match. + When passing literal separators to match against, the + max separator length will be set automatically. + + .. note:: For best results, a separator regex should + both begin and end with data which is as + unique as possible, and should not start or + end with optional or repeated elements. + Otherwise, you run the risk of failing to + match parts of a separator when it is split + across multiple reads. If EOF or a signal is received before a match occurs, an :exc:`IncompleteReadError ` @@ -202,7 +217,8 @@ async def readuntil(self, separator: object) -> AnyStr: """ - return await self._session.readuntil(separator, self._datatype) + return await self._session.readuntil(separator, self._datatype, + max_separator_len) async def readexactly(self, n: int) -> AnyStr: """Read an exact amount of data from the stream @@ -220,7 +236,7 @@ async def readexactly(self, n: int) -> AnyStr: """ - return await self._session.read(n, self._datatype, exact=True) + return await self._session.read(self._datatype, n, exact=True) def at_eof(self) -> bool: """Return whether the stream is at EOF @@ -375,6 +391,12 @@ def __init__(self) -> None: self._read_waiters: _ReadWaiters = {None: None} self._drain_waiters: _DrainWaiters = {None: set()} + async def aiter(self, datatype: DataType) -> AsyncIterator[AnyStr]: + """Allow SSHReader to be an async iterator""" + + while not self.at_eof(datatype): + yield await self.readline(datatype) + async def _block_read(self, datatype: DataType) -> None: """Wait for more data to arrive on the stream""" @@ -504,17 +526,19 @@ def resume_writing(self) -> None: for datatype in self._drain_waiters: self._unblock_drain(datatype) - async def read(self, n: int, datatype: DataType, exact: bool) -> AnyStr: + async def read(self, datatype: DataType, n: int, exact: bool) -> AnyStr: """Read data from the channel""" recv_buf = self._recv_buf[datatype] data: List[AnyStr] = [] + break_read = False async with self._read_locks[datatype]: while True: while recv_buf and n != 0: if isinstance(recv_buf[0], Exception): if data: + break_read = True break else: exc = cast(Exception, recv_buf.pop(0)) @@ -542,7 +566,8 @@ async def read(self, n: int, datatype: DataType, exact: bool) -> AnyStr: continue if n == 0 or (n > 0 and data and not exact) or \ - (n < 0 and recv_buf) or self._eof_received: + (n < 0 and recv_buf) or \ + self._eof_received or break_read: break await self._block_read(datatype) @@ -555,7 +580,16 @@ async def read(self, n: int, datatype: DataType, exact: bool) -> AnyStr: return result - async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: + async def readline(self, datatype: DataType) -> AnyStr: + """Read one line from the stream""" + + try: + return await self.readuntil(_NEWLINE, datatype) + except asyncio.IncompleteReadError as exc: + return cast(AnyStr, exc.partial) + + async def readuntil(self, separator: object, datatype: DataType, + max_separator_len = 0) -> AnyStr: """Read data from the channel until a separator is seen""" if not separator: @@ -567,16 +601,20 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: if separator is _NEWLINE: seplen = 1 separators = cast(AnyStr, '\n' if self._encoding else b'\n') + pat = re.compile(separators) elif isinstance(separator, (bytes, str)): seplen = len(separator) - separators = cast(AnyStr, separator) + pat = re.compile(re.escape(cast(AnyStr, separator))) + elif isinstance(separator, Pattern): + seplen = max_separator_len + pat = cast(Pattern[AnyStr], separator) else: bar = cast(AnyStr, '|' if self._encoding else b'|') seplist = list(cast(Iterable[AnyStr], separator)) seplen = max(len(sep) for sep in seplist) separators = bar.join(re.escape(sep) for sep in seplist) + pat = re.compile(separators) - pat = re.compile(separators) curbuf = 0 buflen = 0 @@ -599,7 +637,7 @@ async def readuntil(self, separator: object, datatype: DataType) -> AnyStr: newbuf = cast(AnyStr, recv_buf[curbuf]) buf += newbuf - start = max(buflen + 1 - seplen, 0) + start = 0 if seplen == 0 else max(buflen + 1 - seplen, 0) match = pat.search(buf, start) if match: @@ -667,7 +705,7 @@ def __init__(self, session_factory: _OptServerSessionFactory, self._sftp_version = sftp_version self._allow_scp = allow_scp and bool(sftp_factory) - def _init_sftp_server(self) -> SFTPServer: + def _init_sftp_server(self) -> MaybeAwait[SFTPServer]: """Initialize an SFTP server for this stream to use""" assert self._chan is not None @@ -733,35 +771,36 @@ def session_started(self) -> None: if inspect.isawaitable(handler): assert self._conn is not None - assert handler is not None self._conn.create_task(handler, stdin.logger) + def exception_received(self, exc: Exception) -> None: + """Handle an incoming exception on the channel""" + + self._recv_buf[None].append(exc) + self._unblock_read(None) + def break_received(self, msec: int) -> bool: """Handle an incoming break on the channel""" - self._recv_buf[None].append(BreakReceived(msec)) - self._unblock_read(None) + self.exception_received(BreakReceived(msec)) return True def signal_received(self, signal: str) -> None: """Handle an incoming signal on the channel""" - self._recv_buf[None].append(SignalReceived(signal)) - self._unblock_read(None) + self.exception_received(SignalReceived(signal)) def soft_eof_received(self) -> None: """Handle an incoming soft EOF on the channel""" - self._recv_buf[None].append(SoftEOFReceived()) - self._unblock_read(None) + self.exception_received(SoftEOFReceived()) def terminal_size_changed(self, width: int, height: int, pixwidth: int, pixheight: int) -> None: """Handle an incoming terminal size change on the channel""" - self._recv_buf[None].append(TerminalSizeChanged(width, height, - pixwidth, pixheight)) - self._unblock_read(None) + self.exception_received(TerminalSizeChanged(width, height, + pixwidth, pixheight)) class SSHSocketStreamSession(SSHStreamSession[AnyStr]): @@ -784,7 +823,6 @@ def session_started(self) -> None: if inspect.isawaitable(handler): assert self._conn is not None - assert handler is not None self._conn.create_task(handler, reader.logger) @@ -796,3 +834,34 @@ class SSHTCPStreamSession(SSHSocketStreamSession[AnyStr], class SSHUNIXStreamSession(SSHSocketStreamSession[AnyStr], SSHUNIXSession[AnyStr]): """UNIX stream session handler""" + +class SSHTunTapStreamSession(SSHSocketStreamSession[bytes], SSHTunTapSession): + """TUN/TAP stream session handler""" + + async def aiter(self, datatype: DataType) -> AsyncIterator[bytes]: + """Allow SSHReader to be an async iterator""" + + while True: + packet = await self.read(datatype) + + if packet: + yield packet + else: + break + + async def read(self, datatype: DataType, n: int = -1, + exact: bool = False) -> bytes: + """Override read to preserve TUN/TAP packet boundaries""" + + recv_buf = self._recv_buf[datatype] + + while not self._eof_received: + if recv_buf: + data = cast(bytes, recv_buf.pop(0)) + self._recv_buf_len -= len(data) + self._maybe_resume_reading() + return data + else: + await self._block_read(datatype) + + return b'' diff --git a/asyncssh/subprocess.py b/asyncssh/subprocess.py index 7655e0b..3b1c30f 100644 --- a/asyncssh/subprocess.py +++ b/asyncssh/subprocess.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2021 by Ron Frederick and others. +# Copyright (c) 2019-2023 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -94,8 +94,8 @@ def get_write_buffer_size(self) -> int: return self._chan.get_write_buffer_size() - def set_write_buffer_limits(self, high: int = None, - low: int = None) -> None: + def set_write_buffer_limits(self, high: Optional[int] = None, + low: Optional[int] = None) -> None: """Set the high- and low-water limits for write flow control""" self._chan.set_write_buffer_limits(high, low) @@ -134,7 +134,7 @@ def connection_made(self, transport: 'SSHSubprocessTransport[AnyStr]') -> None: """Called when a remote process is successfully started - This method is called when a a remote process is successfully + This method is called when a remote process is successfully started. The transport parameter should be stored if needed for later use. diff --git a/asyncssh/tuntap.py b/asyncssh/tuntap.py new file mode 100644 index 0000000..63130ed --- /dev/null +++ b/asyncssh/tuntap.py @@ -0,0 +1,431 @@ +# Copyright (c) 2024 by Ron Frederick and others. +# +# This program and the accompanying materials are made available under +# the terms of the Eclipse Public License v2.0 which accompanies this +# distribution and is available at: +# +# http://www.eclipse.org/legal/epl-2.0/ +# +# This program may also be made available under the following secondary +# licenses when the conditions for such availability set forth in the +# Eclipse Public License v2.0 are satisfied: +# +# GNU General Public License, Version 2.0, or any later versions of +# that license +# +# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later +# +# Contributors: +# Ron Frederick - initial implementation, API, and documentation + +"""SSH TUN/TAP forwarding support""" + +import asyncio +import errno +import os +import socket +import struct +import sys +import threading + +from typing import Callable, Optional, Tuple, cast + +if sys.platform != 'win32': # pragma: no branch + import fcntl + + +SSH_TUN_MODE_POINTTOPOINT = 1 # layer 3 IP packets +SSH_TUN_MODE_ETHERNET = 2 # layer 2 Ethenet frames + +SSH_TUN_UNIT_ANY = 0x7fffffff # The server may choose the unit + +SSH_TUN_AF_INET = 2 # IPv4 +SSH_TUN_AF_INET6 = 24 # IPv6 + +DARWIN_CTLIOCGINFO = 0xc0644e03 +DARWIN_CTLIOCGINFO_FMT = 'I96s' + +DARWIN_SIOCGIFFLAGS = 0xc0206911 +DARWIN_SIOCSIFFLAGS = 0x80206910 + +LINUX_TUNSETIFF = 0x400454ca +LINUX_IFF_TUN = 0x1 +LINUX_IFF_TAP = 0x2 +LINUX_IFF_NO_PI = 0x1000 + +IFF_FMT = '16sH' +IFF_UP = 0x1 + + +class SSHTunTapTransport(asyncio.Transport): + """Layer 2/3 tunnel transport""" + + def __init__(self, loop: asyncio.AbstractEventLoop, interface: str): + super().__init__(extra={'interface': interface}) + + self._loop = loop + self._protocol: Optional[asyncio.Protocol] = None + + def get_protocol(self) -> asyncio.BaseProtocol: # pragma: no cover + """Get protocol object associated with transport""" + + assert self._protocol is not None + + return self._protocol + + def set_protocol(self, protocol: asyncio.BaseProtocol) -> None: + """Set protocol associated with transport""" + + self._protocol = cast(asyncio.Protocol, protocol) + + def abort(self) -> None: # pragma: no cover + """Abort this transport""" + + self.close() + + def is_reading(self) -> bool: + """Return if the transport is reading data""" + + raise NotImplementedError + + def pause_reading(self) -> None: + """Pause reading""" + + raise NotImplementedError + + def resume_reading(self) -> None: + """Resume reading""" + + raise NotImplementedError + + def can_write_eof(self) -> bool: # pragma: no cover + """This transport doesn't support writing EOF""" + + return False + + def get_write_buffer_size(self) -> int: # pragma: no cover + """This transport has no output buffer""" + + return 0 + + def get_write_buffer_limits(self) -> Tuple[int, int]: # pragma: no cover + """This transport doesn't support write buffer limits""" + + return 0, 0 + + def set_write_buffer_limits(self, high: Optional[int] = None, + low: Optional[int] = None) -> None: + """This transport doesn't support write buffer limits""" + + def write_eof(self) -> None: + """Ignore writing EOF on this transport""" + + def write(self, data: bytes) -> None: + """Write a packet""" + + raise NotImplementedError + + def is_closing(self) -> bool: # pragma: no cover + """Return if the transport is closing""" + + return False + + def close(self) -> None: + """Close this transport""" + + raise NotImplementedError + + +class SSHTunTapOSXTransport(SSHTunTapTransport): + """TunTapOSX transport""" + + def __init__(self, loop: asyncio.AbstractEventLoop, mode: int, + unit: Optional[int]): + prefix = 'tun' if mode == SSH_TUN_MODE_POINTTOPOINT else 'tap' + + if unit is None: + for i in range(16): + try: + file = open(f'/dev/{prefix}{i}', 'rb+', buffering=0) + except OSError: + pass + else: + unit = i + break + else: + raise OSError(errno.EBUSY, f'No {prefix} devices available') + else: + file = open(f'/dev/{prefix}{unit}', 'rb+', buffering=0) + + interface = f'{prefix}{unit}' + name = interface.encode() + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + + try: + ifr = struct.pack(IFF_FMT, name, 0) + ifr = fcntl.ioctl(sock, DARWIN_SIOCGIFFLAGS, ifr) + + _, flags = struct.unpack(IFF_FMT, ifr) + flags |= IFF_UP + + ifr = struct.pack(IFF_FMT, name, flags) + fcntl.ioctl(sock, DARWIN_SIOCSIFFLAGS, ifr) + finally: + sock.close() + + super().__init__(loop, interface) + + self._file = file + self._read_thread: Optional[threading.Thread] = None + os.set_blocking(file.fileno(), True) + + def is_reading(self) -> bool: + """Return if the transport is reading data""" + + return self._read_thread is not None # pragma: no cover + + def pause_reading(self) -> None: + """Pause reading""" + + if self._read_thread: # pragma: no branch + self._read_thread.join() + self._read_thread = None + + def resume_reading(self) -> None: + """Resume reading""" + + if not self._read_thread: # pragma: no branch + self._read_thread = threading.Thread(target=self._read_loop) + self._read_thread.daemon = True + self._read_thread.start() + + def _read_loop(self) -> None: + """Loop reading packets until read is paused or done""" + + assert self._protocol is not None + + while True: + try: + data = self._file.read(65536) + except OSError as exc: + if exc.errno != errno.EBADF: # pragma: no cover + self._loop.call_soon_threadsafe( + self._protocol.connection_lost, exc) + + break + else: + self._loop.call_soon_threadsafe( + self._protocol.data_received, data) + + def write(self, data: bytes) -> None: + """Write a packet""" + + self._file.write(data) + + def close(self) -> None: + """Close this transport""" + + self._file.close() + self.pause_reading() + + +class SSHDarwinUTunTransport(SSHTunTapTransport): + """Darwin UTun transport""" + + def __init__(self, loop: asyncio.AbstractEventLoop, unit: Optional[int]): + sock = socket.socket(socket.PF_SYSTEM, socket.SOCK_DGRAM, + socket.SYSPROTO_CONTROL) + + try: + arg = struct.pack(DARWIN_CTLIOCGINFO_FMT, 0, + b'com.apple.net.utun_control') + + ctl_info = fcntl.ioctl(sock, DARWIN_CTLIOCGINFO, arg) + ctl_id, _ = struct.unpack(DARWIN_CTLIOCGINFO_FMT, ctl_info) + + unit = 0 if unit is None else unit - 15 + + sock.setblocking(False) + sock.connect((ctl_id, unit)) + + _, unit = sock.getpeername() + except OSError: + sock.close() + raise + + unit: int + + super().__init__(loop, f'utun{unit-1}') + + self._sock = sock + self._reading = False + + def is_reading(self) -> bool: # pragma: no cover + """Return if the transport is reading data""" + + return self._reading + + def pause_reading(self) -> None: + """Pause reading""" + + self._reading = False + self._loop.remove_reader(self._sock) + + def resume_reading(self) -> None: + """Resume reading""" + + self._reading = True + self._loop.add_reader(self._sock, self._read_ready) + + def _read_ready(self) -> None: + """Read available packets from the transport""" + + assert self._protocol is not None + + while True: + try: + data = self._sock.recv(65540)[4:] + except (BlockingIOError, InterruptedError): + break + except OSError as exc: # pragma: no cover + self._protocol.connection_lost(exc) + break + else: + self._protocol.data_received(data) + + def write(self, data: bytes) -> None: + """Write a packet""" + + version = data[0] >> 4 + family = socket.AF_INET if version == 4 else socket.AF_INET6 + data = family.to_bytes(4, 'big') + data + + self._sock.send(data) + + def close(self) -> None: + """Close this transport""" + + self._sock.close() + self.pause_reading() + + +class SSHLinuxTunTapTransport(SSHTunTapTransport): + """Linux TUN/TAP transport""" + + def __init__(self, loop: asyncio.AbstractEventLoop, mode: int, + unit: Optional[int]): + file = open('/dev/net/tun', 'rb+', buffering=0) + + if mode == SSH_TUN_MODE_POINTTOPOINT: + flags = LINUX_IFF_TUN | LINUX_IFF_NO_PI + prefix = 'tun' + else: + flags = LINUX_IFF_TAP | LINUX_IFF_NO_PI + prefix = 'tap' + + name = b'' if unit is None else f'{prefix}{unit}'.encode() + + ifr = struct.pack(IFF_FMT, name, flags) + + try: + ifr = fcntl.ioctl(file, LINUX_TUNSETIFF, ifr) + except OSError: + file.close() + raise + + name, _ = struct.unpack(IFF_FMT, ifr) + interface = name.strip(b'\0').decode() + + super().__init__(loop, interface) + + self._file = file + self._reading = False + os.set_blocking(file.fileno(), False) + + def is_reading(self) -> bool: # pragma: no cover + """Return if the transport is reading data""" + + return self._reading + + def pause_reading(self) -> None: + """Pause reading""" + + self._reading = False + + try: + self._loop.remove_reader(self._file) + except OSError: # pragma: no cover + pass + + def resume_reading(self) -> None: + """Resume reading""" + + self._reading = True + self._loop.add_reader(self._file, self._read_ready) + + def _read_ready(self) -> None: + """Read available packets from the transport""" + + assert self._protocol is not None + + while True: + try: + data = self._file.read(65536) + except OSError as exc: # pragma: no cover + self._protocol.connection_lost(exc) + break + else: + if data is None: + break + + self._protocol.data_received(data) + + def write(self, data: bytes) -> None: + """Write a packet""" + + self._file.write(data) + + def close(self) -> None: + """Close this transport""" + + self._file.close() + self.pause_reading() + + +def create_tuntap(protocol_factory: Callable[[], asyncio.BaseProtocol], + mode: int, unit: Optional[int]) -> \ + Tuple[SSHTunTapTransport, asyncio.BaseProtocol]: + """Create a local TUN or TAP network interface""" + + loop = asyncio.get_event_loop() + transport: Optional[SSHTunTapTransport] = None + + if sys.platform == 'darwin': + if unit is None: + try: + transport = SSHTunTapOSXTransport(loop, mode, unit) + except OSError: + if mode == SSH_TUN_MODE_POINTTOPOINT: + transport = SSHDarwinUTunTransport(loop, unit) + else: + raise + elif mode == SSH_TUN_MODE_POINTTOPOINT and unit >= 16: + transport = SSHDarwinUTunTransport(loop, unit) + else: + transport = SSHTunTapOSXTransport(loop, mode, unit) + elif sys.platform == 'linux': + transport = SSHLinuxTunTapTransport(loop, mode, unit) + else: + raise OSError(errno.EPROTONOSUPPORT, + f'TunTap not supported on {sys.platform}') + + assert transport is not None + + protocol = protocol_factory() + protocol.connection_made(transport) + + transport.set_protocol(protocol) + transport.resume_reading() + + return transport, protocol diff --git a/asyncssh/version.py b/asyncssh/version.py index 891d8cf..aeeaccb 100644 --- a/asyncssh/version.py +++ b/asyncssh/version.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -26,4 +26,4 @@ __url__ = 'http://asyncssh.timeheart.net' -__version__ = '2.10.1' +__version__ = '2.21.1' diff --git a/asyncssh/x11.py b/asyncssh/x11.py index 7765b64..780ed55 100644 --- a/asyncssh/x11.py +++ b/asyncssh/x11.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -242,7 +242,7 @@ def data_received(self, data: bytes, datatype: DataType = None) -> None: if self._recv_handler: self._inpbuf += data - while self._recv_handler: + while self._recv_handler: # type: ignore[truthy-function] if len(self._inpbuf) >= self._bytes_needed: data = self._inpbuf[:self._bytes_needed] self._inpbuf = self._inpbuf[self._bytes_needed:] @@ -345,7 +345,7 @@ def attach(self, chan: 'SSHChannel', screen: int) -> str: self._channels.add(chan) - return '%s.%s' % (self._display, screen) + return f'{self._display}.{screen}' def detach(self, chan: 'SSHChannel') -> bool: """Detach a channel from this listener""" @@ -534,7 +534,7 @@ async def create_x11_server_listener(conn: 'SSHServerConnection', except OSError: continue - display = '%s:%d' % (X11_LISTEN_HOST, dpynum) + display = f'{X11_LISTEN_HOST}:{dpynum}' try: await update_xauth(auth_path, X11_LISTEN_HOST, str(dpynum), diff --git a/debian/.gitignore b/debian/.gitignore new file mode 100644 index 0000000..2c8afeb --- /dev/null +++ b/debian/.gitignore @@ -0,0 +1 @@ +/files diff --git a/debian/changelog b/debian/changelog index fc6cca5..d7a9f6f 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,83 @@ +python-asyncssh (2.21.1-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Tue, 30 Sep 2025 13:32:07 +0100 + +python-asyncssh (2.21.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Wed, 13 Aug 2025 11:05:27 +0100 + +python-asyncssh (2.20.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Mon, 24 Feb 2025 18:29:44 +0000 + +python-asyncssh (2.19.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Fri, 13 Dec 2024 13:59:18 +0000 + +python-asyncssh (2.18.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Mon, 28 Oct 2024 12:28:19 +0000 + +python-asyncssh (2.17.0-2) unstable; urgency=medium + + * Team upload. + * Use bcrypt rather than crypt in simple_server example (closes: + #1084679). + + -- Colin Watson Sun, 13 Oct 2024 15:02:06 +0100 + +python-asyncssh (2.17.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Fri, 06 Sep 2024 11:14:21 +0100 + +python-asyncssh (2.16.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release. + + -- Colin Watson Mon, 19 Aug 2024 11:59:06 +0100 + +python-asyncssh (2.15.0-1) unstable; urgency=medium + + * Team upload. + * New upstream release (closes: #1076423): + - Hide cryptography 37.0.0 deprecation warnings (closes: #1069811). + - CVE-2023-48795: Implemented "strict kex" support and other + countermeasures to protect against the Terrapin Attack (closes: + #1059007). + - CVE-2023-46445, CVE-2023-46446: Hardened AsyncSSH state machine + against potential message injection attacks (closes: #1055999, + #1056000). + * Build-depend on openssl-provider-legacy where available; some tests need + it. + * Drop "Make Sphinx use default theme" and "Revert fido 0.9.2 support" + patches, as the relevant dependencies have since been upgraded. + * Deduplicate results from getaddrinfo (closes: #1052788). + * Enable PKCS#11 tests at build time, since python3-pkcs11 is now + packaged. + * Use pybuild-plugin-pyproject. + * Run tests using pytest. + + -- Colin Watson Sun, 18 Aug 2024 12:25:04 +0100 + python-asyncssh (2.10.1-2) unstable; urgency=medium * Team Upload. diff --git a/debian/control b/debian/control index 63a04e7..7e9effc 100644 --- a/debian/control +++ b/debian/control @@ -5,22 +5,24 @@ Maintainer: Debian Python Team Uploaders: Vincent Bernat Build-Depends: debhelper-compat (= 12), + dh-python, localehelper, + openssh-client, + openssl, + openssl-provider-legacy | openssl (<< 3.3.1-5~), + pybuild-plugin-pyproject, python3-all, - python3-sphinx (>= 1.0.7+dfsg-1~), - python3-setuptools, - python3-cryptography, python3-bcrypt (>= 3.1.3) , + python3-cryptography (>= 39.0), python3-fido2 , python3-gssapi , python3-libnacl , python3-openssl , + python3-pkcs11 , + python3-pytest , + python3-setuptools, + python3-sphinx (>= 1.0.7+dfsg-1~), python3-typing-extensions , -# not packaged -# python3-pkcs11 , - openssl, - openssh-client, - dh-python Standards-Version: 4.5.1 Homepage: https://github.com/ronf/asyncssh Vcs-Browser: https://salsa.debian.org/python-team/packages/python-asyncssh diff --git a/debian/copyright b/debian/copyright index 9ae9854..687f69c 100644 --- a/debian/copyright +++ b/debian/copyright @@ -4,7 +4,7 @@ Upstream-Contact: Ron Frederick Source: https://github.com/ronf/asyncssh Files: * -Copyright: Copyright (c) 2013-2022 by Ron Frederick and others. +Copyright: Copyright (c) 2013-2025 by Ron Frederick and others. License: EPL-1.0 Files: debian/* diff --git a/debian/patches/0003-Revert-fido-0.9.2-support.patch b/debian/patches/0003-Revert-fido-0.9.2-support.patch deleted file mode 100644 index 7b2c1b8..0000000 --- a/debian/patches/0003-Revert-fido-0.9.2-support.patch +++ /dev/null @@ -1,52 +0,0 @@ -From: Stefano Rivera -Date: Wed, 18 May 2022 09:03:00 -0400 -Subject: Revert fido 0.9.2 support - -Not yet in Debian, we're still on 0.9.1 - -This reverts: 7a4597953a631ee5091ac1b6e384e32d4f018a82 ---- - asyncssh/sk.py | 2 +- - setup.py | 2 +- - tests/sk_stub.py | 2 +- - 3 files changed, 3 insertions(+), 3 deletions(-) - -diff --git a/asyncssh/sk.py b/asyncssh/sk.py -index 969bff9..56510b0 100644 ---- a/asyncssh/sk.py -+++ b/asyncssh/sk.py -@@ -167,7 +167,7 @@ def sk_enroll(alg: int, application: bytes, user: str, - try: - return _ctap2_enroll(dev, alg, application, user, pin, resident) - except CtapError as exc: -- if exc.code == CtapError.ERR.PUAT_REQUIRED: -+ if exc.code == CtapError.ERR.PIN_REQUIRED: - raise ValueError('PIN required') from None - elif exc.code == CtapError.ERR.PIN_INVALID: - raise ValueError('Invalid PIN') from None -diff --git a/setup.py b/setup.py -index b9df796..5d773bf 100755 ---- a/setup.py -+++ b/setup.py -@@ -60,7 +60,7 @@ setup(name = 'asyncssh', - install_requires = ['cryptography >= 3.1', 'typing_extensions >= 3.6'], - extras_require = { - 'bcrypt': ['bcrypt >= 3.1.3'], -- 'fido2': ['fido2 >= 0.9.2'], -+ 'fido2': ['fido2 == 0.9.1'], - 'gssapi': ['gssapi >= 1.2.0'], - 'libnacl': ['libnacl >= 1.4.2'], - 'pkcs11': ['python-pkcs11 >= 0.7.0'], -diff --git a/tests/sk_stub.py b/tests/sk_stub.py -index ffba14c..aeb6b88 100644 ---- a/tests/sk_stub.py -+++ b/tests/sk_stub.py -@@ -195,7 +195,7 @@ class Ctap2(_CtapStub): - if self.dev.error == 'err': - raise CtapError(CtapError.ERR.INVALID_CREDENTIAL) - elif self.dev.error == 'pinreq': -- raise CtapError(CtapError.ERR.PUAT_REQUIRED) -+ raise CtapError(CtapError.ERR.PIN_REQUIRED) - elif self.dev.error == 'badpin': - raise CtapError(CtapError.ERR.PIN_INVALID) - diff --git a/debian/patches/0004-Handle-ConnectionRefusedError-when-connecting-to-223.patch b/debian/patches/0004-Handle-ConnectionRefusedError-when-connecting-to-223.patch deleted file mode 100644 index 3e50758..0000000 --- a/debian/patches/0004-Handle-ConnectionRefusedError-when-connecting-to-223.patch +++ /dev/null @@ -1,58 +0,0 @@ -From: Stefano Rivera -Date: Wed, 18 May 2022 09:21:27 -0400 -Subject: Handle ConnectionRefusedError when connecting to 223.255.255.254 - -If the tests are run from an environment with a firewall, they may be -refused instead of timing out. - -Just skip the test. - -Forwarded: https://github.com/ronf/asyncssh/pull/480 ---- - tests/test_connection.py | 23 ++++++++++++++++------- - 1 file changed, 16 insertions(+), 7 deletions(-) - -diff --git a/tests/test_connection.py b/tests/test_connection.py -index 9a3871c..9eec850 100644 ---- a/tests/test_connection.py -+++ b/tests/test_connection.py -@@ -425,23 +425,32 @@ class _TestConnection(ServerTestCase): - async def test_connect_timeout_exceeded(self): - """Test connect timeout exceeded""" - -- with self.assertRaises(asyncio.TimeoutError): -- await asyncssh.connect('223.255.255.254', connect_timeout=1) -+ try: -+ with self.assertRaises(asyncio.TimeoutError): -+ await asyncssh.connect('223.255.255.254', connect_timeout=1) -+ except ConnectionRefusedError: -+ raise unittest.SkipTest("Outboand connection firewalled") - - @asynctest - async def test_connect_timeout_exceeded_string(self): - """Test connect timeout exceeded with string value""" - -- with self.assertRaises(asyncio.TimeoutError): -- await asyncssh.connect('223.255.255.254', connect_timeout='0m1s') -+ try: -+ with self.assertRaises(asyncio.TimeoutError): -+ await asyncssh.connect('223.255.255.254', connect_timeout='0m1s') -+ except ConnectionRefusedError: -+ raise unittest.SkipTest("Outboand connection firewalled") - - @asynctest - async def test_connect_timeout_exceeded_tunnel(self): - """Test connect timeout exceeded""" - -- with self.assertRaises(asyncio.TimeoutError): -- await asyncssh.listen(server_host_keys=['skey'], -- tunnel='223.255.255.254', connect_timeout=1) -+ try: -+ with self.assertRaises(asyncio.TimeoutError): -+ await asyncssh.listen(server_host_keys=['skey'], -+ tunnel='223.255.255.254', connect_timeout=1) -+ except ConnectionRefusedError: -+ raise unittest.SkipTest("Outboand connection firewalled") - - @asynctest - async def test_invalid_connect_timeout(self): diff --git a/debian/patches/mock-pathlib-expanduser.patch b/debian/patches/mock-pathlib-expanduser.patch deleted file mode 100644 index 24537b8..0000000 --- a/debian/patches/mock-pathlib-expanduser.patch +++ /dev/null @@ -1,30 +0,0 @@ -From 2ff28a51439ec687be687f9b3d204316e60cabcd Mon Sep 17 00:00:00 2001 -From: Georg Sauthoff -Date: Sat, 9 Jul 2022 17:22:55 +0200 -Subject: [PATCH] Also patch pathlib expanduser - -NB: with recent Python versions the existing `os.path.expanduser()` patch -also affects `pathlib.path.expanduser()` which is invoked by the config -parser for expanding `~/.ssh`. ---- - tests/test_config.py | 6 +++++- - 1 file changed, 5 insertions(+), 1 deletion(-) - -diff --git a/tests/test_config.py b/tests/test_config.py -index 31f055ec..ce3083e3 100644 ---- a/tests/test_config.py -+++ b/tests/test_config.py -@@ -417,8 +417,12 @@ def mock_expanduser(path): - - return path - -+ def mock_pathlib_expanduser(s): -+ return s._from_parts([os.environ['HOME']] + s._parts[1:]) -+ - with self.assertRaises(asyncssh.ConfigParseError): -- with patch('os.path.expanduser', mock_expanduser): -+ with patch('os.path.expanduser', mock_expanduser), \ -+ patch('pathlib.Path.expanduser', mock_pathlib_expanduser): - self._parse_config('RemoteCommand %d') - - def test_uid_percent_expansion_unavailable(self): diff --git a/debian/patches/series b/debian/patches/series index 9618494..79e9c29 100644 --- a/debian/patches/series +++ b/debian/patches/series @@ -1,5 +1 @@ -sphinx-use-default-theme.patch -0002-skip-tests-requiring-network-access.patch -0003-Revert-fido-0.9.2-support.patch -0004-Handle-ConnectionRefusedError-when-connecting-to-223.patch -mock-pathlib-expanduser.patch +skip-tests-requiring-network-access.patch diff --git a/debian/patches/0002-skip-tests-requiring-network-access.patch b/debian/patches/skip-tests-requiring-network-access.patch similarity index 87% rename from debian/patches/0002-skip-tests-requiring-network-access.patch rename to debian/patches/skip-tests-requiring-network-access.patch index 130a0b3..805f18e 100644 --- a/debian/patches/0002-skip-tests-requiring-network-access.patch +++ b/debian/patches/skip-tests-requiring-network-access.patch @@ -1,13 +1,13 @@ From: Vincent Bernat Date: Sun, 3 Jan 2016 18:11:46 +0100 -Subject: skip tests requiring network access +Subject: Skip tests requiring network access --- tests/test_auth_keys.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_auth_keys.py b/tests/test_auth_keys.py -index 27d4190..eba49e1 100644 +index 9e1a05c..da8b3d7 100644 --- a/tests/test_auth_keys.py +++ b/tests/test_auth_keys.py @@ -95,6 +95,7 @@ class _TestAuthorizedKeys(TempDirTestCase): diff --git a/debian/patches/sphinx-use-default-theme.patch b/debian/patches/sphinx-use-default-theme.patch deleted file mode 100644 index 15cecdf..0000000 --- a/debian/patches/sphinx-use-default-theme.patch +++ /dev/null @@ -1,35 +0,0 @@ -From: SVN-Git Migration -Date: Thu, 8 Oct 2015 11:09:41 -0700 -Subject: make Sphinx use default theme - -Forwarded: not-needed - -The "classic" theme is introduced in Sphinx 1.3 and not available in -Sphinx 1.2. - -Patch-Name: sphinx-use-default-theme.patch ---- - docs/rftheme/static/rftheme.css_t | 2 +- - docs/rftheme/theme.conf | 2 +- - 2 files changed, 2 insertions(+), 2 deletions(-) - -diff --git a/docs/rftheme/static/rftheme.css_t b/docs/rftheme/static/rftheme.css_t -index 66aad65..77fe744 100644 ---- a/docs/rftheme/static/rftheme.css_t -+++ b/docs/rftheme/static/rftheme.css_t -@@ -1,4 +1,4 @@ --@import url("classic.css"); -+@import url("default.css"); - - .tight-list * { - line-height: 110% !important; -diff --git a/docs/rftheme/theme.conf b/docs/rftheme/theme.conf -index 1c2b15e..a417128 100644 ---- a/docs/rftheme/theme.conf -+++ b/docs/rftheme/theme.conf -@@ -1,4 +1,4 @@ - [theme] --inherit = classic -+inherit = default - stylesheet = rftheme.css - pygments_style = sphinx diff --git a/debian/rules b/debian/rules index 6545d32..e499088 100755 --- a/debian/rules +++ b/debian/rules @@ -14,4 +14,4 @@ override_dh_installdocs: override_dh_auto_test: env RES_OPTIONS=attempts:0 localehelper LANG=en_US.UTF-8 \ - dh_auto_test -- --system=custom --test-args='{interpreter} -m unittest discover -v' + dh_auto_test diff --git a/debian/tests/autopkgtest-pkg-python.conf b/debian/tests/autopkgtest-pkg-python.conf new file mode 100644 index 0000000..2dc3bc9 --- /dev/null +++ b/debian/tests/autopkgtest-pkg-python.conf @@ -0,0 +1,2 @@ +extra_depends = \ + openssl-provider-legacy | openssl (<< 3.3.1-5~), diff --git a/docs/api.rst b/docs/api.rst index 7063595..78de0d0 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -37,20 +37,23 @@ Once an SSH client connection is established and authentication is successful, multiple simultaneous channels can be opened on it. This is accomplished calling methods such as :meth:`create_session() `, :meth:`create_connection() -`, and :meth:`create_unix_connection() -` on the -:class:`SSHClientConnection` object. The client can also set up listeners on -remote TCP ports and UNIX domain sockets by calling :meth:`create_server() -` and :meth:`create_unix_server() -`. All of these methods take -``session_factory`` arguments that return :class:`SSHClientSession`, -:class:`SSHTCPSession`, or :class:`SSHUNIXSession` objects used to manage -the channels once they are open. Alternately, channels can be opened using -:meth:`open_session() `, -:meth:`open_connection() `, or +`, :meth:`create_unix_connection() +`, :meth:`create_tun() +`, and :meth:`create_tap() +` on the :class:`SSHClientConnection` object. +The client can also set up listeners on remote TCP ports and UNIX domain +sockets by calling :meth:`create_server() ` +and :meth:`create_unix_server() `. +All of these methods take ``session_factory`` arguments that return +:class:`SSHClientSession`, :class:`SSHTCPSession`, or :class:`SSHUNIXSession` +objects used to manage the channels once they are open. Alternately, channels +can be opened using :meth:`open_session() `, +:meth:`open_connection() `, :meth:`open_unix_connection() `, -which return :class:`SSHReader` and :class:`SSHWriter` objects that can be -used to perform I/O on the channel. The methods :meth:`start_server() +:meth:`open_tun() `, or +:meth:`open_tap() `, which return +:class:`SSHReader` and :class:`SSHWriter` objects that can be used to +perform I/O on the channel. The methods :meth:`start_server() ` and :meth:`start_unix_server() ` can be used to set up listeners on remote TCP ports or UNIX domain sockets and get back these :class:`SSHReader` @@ -75,15 +78,31 @@ The client can also set up TCP port forwarding by calling :meth:`forward_remote_port() ` and UNIX domain socket forwarding by calling :meth:`forward_local_path() ` or :meth:`forward_remote_path() -`. In these cases, data transfer on -the channels is managed automatically by AsyncSSH whenever new connections -are opened, so custom session objects are not required. +`. Mixed forwarding from a TCP port +to a UNIX domain socket or vice-versa can be set up using the functions +:meth:`forward_local_port_to_path() +`, +:meth:`forward_local_path_to_port() +`, +:meth:`forward_remote_port_to_path() +`, and +:meth:`forward_remote_path_to_port() +`. In these cases, data +transfer on the channels is managed automatically by AsyncSSH whenever new +connections are opened, so custom session objects are not required. Dynamic TCP port forwarding can be set up by calling :meth:`forward_socks() `. The SOCKS listener set up by AsyncSSH on the requested port accepts SOCKS connect requests and is compatible with SOCKS versions 4, 4a, and 5. +Bidirectional packet forwarding at layer 2 or 3 is also supported using +the functions :meth:`forward_tun() ` and +:meth:`forward_tap() ` to set up tunnels +between local and remote TUN or TAP interfaces. Once a tunnel is established, +packets arriving on TUN/TAP interfaces on either side are sent over the +tunnel and automatically sent out the TUN/TAP interface on the other side. + When an SSH server receives a new connection and authentication is successful, handlers such as :meth:`session_requested() `, :meth:`connection_requested() `, @@ -125,57 +144,21 @@ found under :ref:`Constants`. Main Functions ============== -connect -------- - .. autofunction:: connect - -connect_reverse ---------------- - .. autofunction:: connect_reverse - -listen ------- - .. autofunction:: listen - -listen_reverse --------------- - .. autofunction:: listen_reverse - -create_connection ------------------ - +.. autofunction:: run_client +.. autofunction:: run_server .. autofunction:: create_connection - -create_server -------------- - .. autofunction:: create_server - -get_server_host_key -------------------- - .. autofunction:: get_server_host_key - -get_server_auth_methods ------------------------ - .. autofunction:: get_server_auth_methods - -scp ---- - .. autofunction:: scp Main Classes ============ -SSHClient ---------- - .. autoclass:: SSHClient ================================== = @@ -222,9 +205,6 @@ SSHClient .. automethod:: kbdint_challenge_received ============================================ = -SSHServer ---------- - .. autoclass:: SSHServer ================================== = @@ -289,14 +269,13 @@ SSHServer .. automethod:: unix_connection_requested .. automethod:: server_requested .. automethod:: unix_server_requested + .. automethod:: tun_requested + .. automethod:: tap_requested ========================================= = Connection Classes ================== -SSHClientConnection -------------------- - .. autoclass:: SSHClientConnection() ======================================================================= = @@ -313,6 +292,7 @@ SSHClientConnection .. automethod:: set_keepalive .. automethod:: get_server_host_key .. automethod:: send_debug + .. automethod:: is_closed =================================== = ====================================================================================================================================================== = @@ -342,18 +322,27 @@ SSHClientConnection .. automethod:: open_unix_connection .. automethod:: create_unix_server .. automethod:: start_unix_server + .. automethod:: create_tun + .. automethod:: create_tap + .. automethod:: open_tun + .. automethod:: open_tap ====================================== = - =================================== = + =========================================== = Client forwarding methods - =================================== = - .. automethod:: forward_connection + =========================================== = .. automethod:: forward_local_port .. automethod:: forward_local_path + .. automethod:: forward_local_port_to_path + .. automethod:: forward_local_path_to_port .. automethod:: forward_remote_port .. automethod:: forward_remote_path + .. automethod:: forward_remote_port_to_path + .. automethod:: forward_remote_path_to_port .. automethod:: forward_socks - =================================== = + .. automethod:: forward_tun + .. automethod:: forward_tap + =========================================== = =========================== = Connection close methods @@ -364,9 +353,6 @@ SSHClientConnection .. automethod:: wait_closed =========================== = -SSHServerConnection -------------------- - .. autoclass:: SSHServerConnection() ======================================================================= = @@ -382,6 +368,7 @@ SSHServerConnection .. automethod:: set_extra_info .. automethod:: set_keepalive .. automethod:: send_debug + .. automethod:: is_closed ============================== = ============================================ = @@ -404,19 +391,13 @@ SSHServerConnection .. automethod:: open_unix_connection ====================================== = - ======================================= = - Server forwarding methods - ======================================= = - .. automethod:: forward_connection - .. automethod:: forward_unix_connection - ======================================= = - ===================================== = Server channel creation methods ===================================== = .. automethod:: create_server_channel .. automethod:: create_tcp_channel .. automethod:: create_unix_channel + .. automethod:: create_tuntap_channel ===================================== = =========================== = @@ -428,22 +409,13 @@ SSHServerConnection .. automethod:: wait_closed =========================== = -SSHClientConnectionOptions --------------------------- - .. autoclass:: SSHClientConnectionOptions() -SSHServerConnectionOptions --------------------------- - .. autoclass:: SSHServerConnectionOptions() Process Classes =============== -SSHClientProcess ----------------- - .. autoclass:: SSHClientProcess ======================================================================= = @@ -485,9 +457,6 @@ SSHClientProcess .. automethod:: wait_closed ======================================================================= = -SSHServerProcess ----------------- - .. autoclass:: SSHServerProcess ============================== = @@ -523,14 +492,8 @@ SSHServerProcess .. automethod:: wait_closed ================================ = -SSHCompletedProcess -------------------- - .. autoclass:: SSHCompletedProcess() -SSHSubprocessReadPipe ---------------------- - .. autoclass:: SSHSubprocessReadPipe() ==================================== = @@ -552,9 +515,6 @@ SSHSubprocessReadPipe .. automethod:: close ======================================================================= = -SSHSubprocessWritePipe ----------------------- - .. autoclass:: SSHSubprocessWritePipe() ==================================== = @@ -581,9 +541,6 @@ SSHSubprocessWritePipe .. automethod:: close ======================================================================= = -SSHSubprocessProtocol ---------------------- - .. autoclass:: SSHSubprocessProtocol ==================================== = @@ -605,9 +562,6 @@ SSHSubprocessProtocol .. automethod:: process_exited ================================== = -SSHSubprocessTransport ----------------------- - .. autoclass:: SSHSubprocessTransport ==================================== = @@ -635,9 +589,6 @@ SSHSubprocessTransport Session Classes =============== -SSHClientSession ----------------- - .. autoclass:: SSHClientSession =============================== = @@ -670,9 +621,6 @@ SSHClientSession .. automethod:: exit_signal_received ==================================== = -SSHServerSession ----------------- - .. autoclass:: SSHServerSession =============================== = @@ -714,9 +662,6 @@ SSHServerSession .. automethod:: terminal_size_changed ===================================== = -SSHTCPSession -------------- - .. autoclass:: SSHTCPSession =============================== = @@ -741,9 +686,6 @@ SSHTCPSession .. automethod:: resume_writing ============================== = -SSHUNIXSession --------------- - .. autoclass:: SSHUNIXSession =============================== = @@ -768,12 +710,33 @@ SSHUNIXSession .. automethod:: resume_writing ============================== = +.. autoclass:: SSHTunTapSession + + =============================== = + General session handlers + =============================== = + .. automethod:: connection_made + .. automethod:: connection_lost + .. automethod:: session_started + =============================== = + + ============================= = + General session read handlers + ============================= = + .. automethod:: data_received + .. automethod:: eof_received + ============================= = + + ============================== = + General session write handlers + ============================== = + .. automethod:: pause_writing + .. automethod:: resume_writing + ============================== = + Channel Classes =============== -SSHClientChannel ----------------- - .. autoclass:: SSHClientChannel() ========================= = @@ -832,9 +795,6 @@ SSHClientChannel .. automethod:: wait_closed ============================= = -SSHServerChannel ----------------- - .. autoclass:: SSHServerChannel() ======================================================================= = @@ -901,9 +861,6 @@ SSHServerChannel .. automethod:: wait_closed ============================= = -SSHLineEditorChannel --------------------- - .. autoclass:: SSHLineEditorChannel() ============================== = @@ -915,9 +872,6 @@ SSHLineEditorChannel .. automethod:: set_echo ============================== = -SSHTCPChannel -------------- - .. autoclass:: SSHTCPChannel() ======================================================================= = @@ -960,9 +914,6 @@ SSHTCPChannel .. automethod:: wait_closed ============================= = -SSHUNIXChannel --------------- - .. autoclass:: SSHUNIXChannel() ======================================================================= = @@ -1005,21 +956,72 @@ SSHUNIXChannel .. automethod:: wait_closed ============================= = +.. autoclass:: SSHTunTapChannel() + + ======================================================================= = + Channel attributes + ======================================================================= = + .. autoattribute:: logger + ======================================================================= = + + ============================== = + General channel info methods + ============================== = + .. automethod:: get_extra_info + .. automethod:: set_extra_info + ============================== = + + ============================== = + General channel read methods + ============================== = + .. automethod:: pause_reading + .. automethod:: resume_reading + ============================== = + + ======================================= = + General channel write methods + ======================================= = + .. automethod:: can_write_eof + .. automethod:: get_write_buffer_size + .. automethod:: set_write_buffer_limits + .. automethod:: write + .. automethod:: writelines + .. automethod:: write_eof + ======================================= = + + ============================= = + General channel close methods + ============================= = + .. automethod:: abort + .. automethod:: close + .. automethod:: is_closing + .. automethod:: wait_closed + ============================= = + +Forwarder Classes +================= + +.. autoclass:: SSHForwarder() + + ============================== = + .. automethod:: get_extra_info + .. automethod:: close + ============================== = + + Listener Classes ================ -SSHAcceptor ------------ - .. autoclass:: SSHAcceptor() - ====================== = + ============================= = + .. automethod:: get_addresses + .. automethod:: get_port + .. automethod:: close + .. automethod:: wait_closed .. automethod:: update - ====================== = - + ============================= = -SSHListener ------------ .. autoclass:: SSHListener() =========================== = @@ -1031,9 +1033,6 @@ SSHListener Stream Classes ============== -SSHReader ---------- - .. autoclass:: SSHReader() ============================== = @@ -1049,9 +1048,6 @@ SSHReader .. automethod:: readexactly ============================== = -SSHWriter ---------- - .. autoclass:: SSHWriter() ============================== = @@ -1071,28 +1067,28 @@ SSHWriter SFTP Support ============ -SFTPClient ----------- - .. autoclass:: SFTPClient() - ======================================================================= = + ======================================= = SFTP client attributes - ======================================================================= = + ======================================= = .. autoattribute:: logger .. autoattribute:: version - ======================================================================= = + .. autoattribute:: limits + .. autoattribute:: supports_remote_copy + ======================================= = - ===================== = + =========================== = File transfer methods - ===================== = + =========================== = .. automethod:: get .. automethod:: put .. automethod:: copy .. automethod:: mget .. automethod:: mput .. automethod:: mcopy - ===================== = + .. automethod:: remote_copy + =========================== = ============================================================================================================================================================================================================================== = File access methods @@ -1147,6 +1143,7 @@ SFTPClient .. automethod:: readdir .. automethod:: listdir .. automethod:: glob + .. automethod:: glob_sftpname ================================================= = =========================== = @@ -1156,13 +1153,11 @@ SFTPClient .. automethod:: wait_closed =========================== = -SFTPClientFile --------------- - .. autoclass:: SFTPClientFile() ================================================= = .. automethod:: read + .. automethod:: read_parallel .. automethod:: write .. automethod:: seek(offset, from_what=SEEK_SET) .. automethod:: tell @@ -1179,9 +1174,6 @@ SFTPClientFile .. automethod:: close ================================================= = -SFTPServer ----------- - .. autoclass:: SFTPServer ============================= = @@ -1237,9 +1229,9 @@ SFTPServer ======================== = Directory access methods ======================== = - .. automethod:: listdir .. automethod:: mkdir .. automethod:: rmdir + .. automethod:: scandir ======================== = ===================== = @@ -1248,21 +1240,14 @@ SFTPServer .. automethod:: exit ===================== = -SFTPAttrs ---------- - .. autoclass:: SFTPAttrs() -SFTPVFSAttrs ------------- - .. autoclass:: SFTPVFSAttrs() -SFTPName --------- - .. autoclass:: SFTPName() +.. autoclass:: SFTPLimits() + .. index:: Public key and certificate support .. _PublicKeySupport: @@ -1330,7 +1315,7 @@ which is trusted by the remote system. Instead of passing tuples of keys and certificates or relying on file naming conventions for certificates, you also have the option of -providing a list of keys and a seperate list of certificates. In this +providing a list of keys and a separate list of certificates. In this case, AsyncSSH will automatically match up the keys with their associated certificates when they are present. @@ -1386,7 +1371,7 @@ with PKIX-SSH, which adds X.509 certificate support to OpenSSH. To specify a subject name pattern instead of a specific certificate, base64-encoded certificate data should be replaced with the string -'Subject:' followed by a a comma-separated list of X.509 relative +'Subject:' followed by a comma-separated list of X.509 relative distinguished name components. AsyncSSH extends the PKIX-SSH syntax to also support matching on a @@ -1459,8 +1444,8 @@ methods. These values can be specified in any of the following ways: to refer to times in the past or positive to refer to times in the future. -SSHKey ------- +Key and certificate classes/functions +------------------------------------- .. autoclass:: SSHKey() @@ -1484,9 +1469,6 @@ SSHKey .. automethod:: append_public_key ============================================== = -SSHKeyPair ----------- - .. autoclass:: SSHKeyPair() ================================= = @@ -1498,9 +1480,6 @@ SSHKeyPair .. automethod:: set_comment ================================= = -SSHCertificate --------------- - .. autoclass:: SSHCertificate() ================================== = @@ -1513,82 +1492,25 @@ SSHCertificate .. automethod:: append_certificate ================================== = -generate_private_key --------------------- - .. autofunction:: generate_private_key - -import_private_key ------------------- - .. autofunction:: import_private_key - -import_public_key ------------------ - .. autofunction:: import_public_key - -import_certificate ------------------- - .. autofunction:: import_certificate - -read_private_key ----------------- - .. autofunction:: read_private_key - -read_public_key ---------------- - .. autofunction:: read_public_key - -read_certificate ----------------- - .. autofunction:: read_certificate - -read_private_key_list ---------------------- - .. autofunction:: read_private_key_list - -read_public_key_list --------------------- - .. autofunction:: read_public_key_list - -read_certificate_list ---------------------- - .. autofunction:: read_certificate_list - -load_keypairs -------------- - .. autofunction:: load_keypairs - -load_public_keys ----------------- - .. autofunction:: load_public_keys - -load_certificates ------------------ - .. autofunction:: load_certificates - -load_pkcs11_keys ----------------- - .. autofunction:: load_pkcs11_keys - -load_resident_keys ------------------- - .. autofunction:: load_resident_keys +.. autofunction:: set_default_skip_rsa_key_validation .. index:: SSH agent support +.. _SSHAgentSupport: SSH Agent Support ================= @@ -1622,9 +1544,6 @@ path to a UNIX domain socket which can be passed as the ``SSH_AUTH_SOCK`` to local applications which need this access. Any requests sent to this socket are forwarded over the SSH connection to the client's ssh-agent. -SSHAgentClient --------------- - .. autoclass:: SSHAgentClient() ===================================== = @@ -1641,9 +1560,6 @@ SSHAgentClient .. automethod:: wait_closed ===================================== = -SSHAgentKeyPair ---------------- - .. autoclass:: SSHAgentKeyPair() ================================= = @@ -1655,9 +1571,6 @@ SSHAgentKeyPair .. automethod:: remove ================================= = -connect_agent -------------- - .. autofunction:: connect_agent .. index:: Config file support @@ -1689,6 +1602,11 @@ The following OpenSSH client config options are currently supported: | AddressFamily | BindAddress + | CanonicalDomains + | CanonicalizeFallbackLocal + | CanonicalizeHostname + | CanonicalizeMaxDots + | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | CertificateFile | ChallengeResponseAuthentication @@ -1731,7 +1649,9 @@ The following OpenSSH client config options are currently supported: For the "Match" conditional, the following criteria are currently supported: | All + | Canonical | Exec + | Final | Host | LocalUser | OriginalHost @@ -1746,6 +1666,10 @@ For the "Match" conditional, the following criteria are currently supported: when options objects are created by AsyncSSH APIs such as :func:`connect` and :func:`listen`. +Match criteria can be negated by prefixing the criteria name with '!'. +This will negate the criteria and causing the match block to be evaluated +only if the negated criteria all fail to match. + The following client config token expansions are currently supported: .. table:: @@ -1786,6 +1710,11 @@ The following OpenSSH server config options are currently supported: | AuthorizedKeysFile | AllowAgentForwarding | BindAddress + | CanonicalDomains + | CanonicalizeFallbackLocal + | CanonicalizeHostname + | CanonicalizeMaxDots + | CanonicalizePermittedCNAMEs | CASignatureAlgorithms | ChallengeResponseAuthentication | Ciphers @@ -1813,7 +1742,9 @@ The following OpenSSH server config options are currently supported: For the "Match" conditional, the following criteria are currently supported: | All + | Canonical | Exec + | Final | Address | Host | LocalAddress @@ -1829,6 +1760,9 @@ For the "Match" conditional, the following criteria are currently supported: when options objects are created by AsyncSSH APIs such as :func:`connect` and :func:`listen`. +Match criteria can be negated by prefixing the criteria name with '!'. +This will negate the criteria and causing the match block to be evaluated +only if the negated criteria all fail to match. The following server config token expansions are currently supported: .. table:: @@ -1915,8 +1849,8 @@ which can be provided, :ref:`SpecifyingCertificates` for the allowed form of certificates, and :ref:`SpecifyingX509Subjects` for the allowed form of X.509 subject names. -SSHKnownHosts -------------- +Known hosts classes/functions +----------------------------- .. autoclass:: SSHKnownHosts() @@ -1924,20 +1858,8 @@ SSHKnownHosts .. automethod:: match ===================== = -import_known_hosts ------------------- - .. autofunction:: import_known_hosts - -read_known_hosts ----------------- - .. autofunction:: read_known_hosts - - -match_known_hosts ------------------ - .. autofunction:: match_known_hosts .. index:: Authorized keys @@ -1975,19 +1897,12 @@ in OpenSSH format or X.509 certificate subject names. See :ref:`SpecifyingX509Subjects` for more information on using subject names in place of specific X.509 certificates. -SSHAuthorizedKeys ------------------ +Authorized keys classes/functions +--------------------------------- .. autoclass:: SSHAuthorizedKeys() -import_authorized_keys ----------------------- - .. autofunction:: import_authorized_keys - -read_authorized_keys --------------------- - .. autofunction:: read_authorized_keys .. index:: Logging @@ -2014,19 +1929,8 @@ be used by application code to output custom log information associated with a particular connection or channel. Logger objects are also provided as members of SFTP client and server objects. -set_log_level -------------- - .. autofunction:: set_log_level - -set_sftp_log_level ------------------- - .. autofunction:: set_sftp_log_level - -set_debug_level ---------------- - .. autofunction:: set_debug_level .. index:: Exceptions @@ -2035,29 +1939,10 @@ set_debug_level Exceptions ========== -PasswordChangeRequired ----------------------- - .. autoexception:: PasswordChangeRequired - -BreakReceived -------------- - .. autoexception:: BreakReceived - -SignalReceived --------------- - .. autoexception:: SignalReceived - -TerminalSizeChanged -------------------- - .. autoexception:: TerminalSizeChanged - -DisconnectError ---------------- - .. autoexception:: DisconnectError .. autoexception:: CompressionError .. autoexception:: ConnectionLost @@ -2069,30 +1954,10 @@ DisconnectError .. autoexception:: ProtocolError .. autoexception:: ProtocolNotSupported .. autoexception:: ServiceNotAvailable - -ChannelOpenError ----------------- - .. autoexception:: ChannelOpenError - -ChannelListenError ------------------- - .. autoexception:: ChannelListenError - -ProcessError ------------- - .. autoexception:: ProcessError - -TimeoutError ------------- - .. autoexception:: TimeoutError - -SFTPError ---------- - .. autoexception:: SFTPError .. autoexception:: SFTPEOFError .. autoexception:: SFTPNoSuchFile @@ -2125,30 +1990,10 @@ SFTPError .. autoexception:: SFTPOwnerInvalid .. autoexception:: SFTPGroupInvalid .. autoexception:: SFTPNoMatchingByteRangeLock - -KeyImportError --------------- - .. autoexception:: KeyImportError - -KeyExportError --------------- - .. autoexception:: KeyExportError - -KeyEncryptionError ------------------- - .. autoexception:: KeyEncryptionError - -KeyGenerationError ------------------- - .. autoexception:: KeyGenerationError - -ConfigParseError ----------------- - .. autoexception:: ConfigParseError .. index:: Supported algorithms @@ -2157,9 +2002,9 @@ ConfigParseError Supported Algorithms ==================== -Algorithms can be specified as either an list of exact algorithm names +Algorithms can be specified as either a list of exact algorithm names or as a string of comma-separated algorithm names that may optionally -include wildcards. A '*' in a name matches zero or more characters and +include wildcards. An '*' in a name matches zero or more characters and a '?' matches exactly one character. When specifying algorithms as a string, it can also be prefixed with '^' @@ -2180,7 +2025,7 @@ by AsyncSSH: | gss-curve25519-sha256 | gss-curve448-sha512 | gss-nistp521-sha512 - | gss-nistp384-sha256 + | gss-nistp384-sha384 | gss-nistp256-sha256 | gss-1.3.132.0.10-sha256 | gss-gex-sha256 @@ -2190,6 +2035,11 @@ by AsyncSSH: | gss-group17-sha512 | gss-group18-sha512 | gss-group14-sha1 + | mlkem768x25519-sha256 + | mlkem768nistp256-sha256 + | mlkem1024nistp384-sha384 + | sntrup761x25519-sha512 + | sntrup761x25519-sha512\@openssh.com | curve25519-sha256 | curve25519-sha256\@libssh.org | curve448-sha512 @@ -2232,6 +2082,9 @@ Curve25519 and Curve448 support is available when OpenSSL 1.1.1 or later is installed. Alternately, Curve25519 is available when the libnacl package and libsodium library are installed. +SNTRUP support is available when the Open Quantum Safe (liboqs) +dynamic library is installed. + .. index:: Encryption algorithms .. _EncryptionAlgs: @@ -2341,6 +2194,7 @@ supported by AsyncSSH: | x509v3-ssh-rsa | sk-ssh-ed25519\@openssh.com | sk-ecdsa-sha2-nistp256\@openssh.com + | webauthn-sk-ecdsa-sha2-nistp256\@openssh.com | ssh-ed25519 | ssh-ed448 | ecdsa-sha2-nistp521 diff --git a/docs/changes.rst b/docs/changes.rst index ea5bb22..fb8d3e4 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -3,9 +3,618 @@ Change Log ========== +Release 2.21.1 (28 Sep 2025) +---------------------------- + +* Added the capability to defer invoking passphrase callback until + an encrypted private key is actually used in a signing operation, + rather than triggering the callback when keys are loaded. This + will only work when a public key is provided with an encrypted + private key either explicitly or as part of the key format (such + as in OpenSSH's private key format). + +* Improved handling of KeyboardInterrupt and task cancellation in + SCP. Thanks go to Viktor Kertesz for reporting this issue and + helping to understand the behavior in various versions of Python. + +* Fixed the env option to support mappings other than dict. Thanks + go to Boris Pavlovic for reporting this issue. + +* Fixed a potential race condition in SSHForwarder cleanup. Thanks + go to GitHub user misa-hase for reporting this issue and helping + to test the fix. + +Release 2.21.0 (2 May 2025) +--------------------------- + +* Added sparse file support for SFTP, allowing file copying which + automatically skips over any "holes" in a source file, transferring + only the data ranges which are actually present. + +* Added support for applications to request that session, connection, + or TUN/TAP requests arriving on an SSHServerConnection be forwarded + out some other established SSHClientConnection. Callback methods on + SSHServer which decide how to handle these requests can now return + an SSHClientConnection to set up this tunneling, instead of having + to accept the request and implement their own forwarding logic. + +* Further hardened the SSH key exchange process to make AsyncSSH + more strict when accepting messages during key exchange. Thanks + go to Fabian Bäumer and Marcus Brinkmann for identifying potential + issues here. + +* Added support for the auth_completed callback in SSHServer to + be either a callable or a coroutine, allowing async operations + to be performed when user authentication completes successfully, + prior to accepting session requests. + +* Added support for the sftp_factory config argument be either a + callable or a coroutine, allowing async operations to be performed + when starting up a new SFTP server session. + +* Fixed a bug where the exit() method of SFTPServer didn't handle + being declared as a coroutine. Thanks go to C. R. Oldham for + reporting this issue. + +* Improved handling of exceptions in connection_lost() callbacks. + Exceptions in connection_lost() will now be reported in the + debug log, but other cleanup code in AsyncSSH will continue, + ignoring those exceptions. Thanks go to Danil Slinchuk for + reporting this issue. + +Release 2.20.0 (17 Feb 2025) +---------------------------- + +* Added support for specifying an explicit path when configuring + agent forwarding. Thanks go to Aleksandr Ilin for pointing out + that this options supports more than just a boolean value. + +* Added support for environment variable expansion in SSH config, + for options which support percent expansion. + +* Added a new begin_auth callback in SSHClient, reporting the + username being sent during SSH client authentication. This can be + useful when the user is conditionally set via an SSH config file. + +* Improved strict-kex interoperability during re-keying. Thanks go + to GitHub user emeryalden for reporting this issue and helping + to track down the source of the problem. + +* Updated SFTP max_requests default to reduce memory usage when + using large block sizes. + +* Updated testing to add Python 3.13 and drop Python 3.7, avoiding + deprecation warnings from the cryptography package. + +* Fixed unit test issues under Windows, allowing unit tests to run + on Windows on all supported versions of Python. + +* Fixed a couple of issues with Python 3.14. Thanks go to Georg + Sauthoff for initially reporting this. + +Release 2.19.0 (12 Dec 2024) +---------------------------- + +* Added support for WebAuthN authentication with U2F security keys, + allowing non-admin Windows users to use these keys for authentication. + Previously, authentication with U2F keys worked on Windows, but only + for admin users. + +* Added support for hostname canonicalization, compatible with the + configuration parameters used in OpenSSH, as well as support for the + "canonical" and "final" match keywords and negation support for + match. Thanks go to GitHub user commonism who suggested this and + provided a proposed implementation for negation. + +* Added client and server support for SFTP copy-data extension and + a new SFTP remote_copy() function which allows data to be moved + between two remote files without downloading and re-uploading the + data. Thanks go to Ali Khosravi for suggesting this addition. + +* Moved project metadata from setup.py to pyproject.toml. Thanks go to + Marc Mueller for contributing this. + +* Updated SSH connection to keep strong references to outstanding + tasks, to avoid potential issues with the garbage collector while + the connection is active. Thanks go to GitHub user Birnendampf for + pointing out this potential issue and suggesting a simple fix. + +* Fixed some issues with block_size argument in SFTP copy functions. + Thanks go to Krzysztof Kotlenga for finding and reporting these issues. + +* Fixed an import error when fido2 package wasn't available. Thanks go + to GitHub user commonism for reporting this issue. + +Release 2.18.0 (26 Oct 2024) +---------------------------- + +* Added support for post-quantum ML-KEM key exchange algorithms, + interoperable with OpenSSH 9.9. + +* Added support for the OpenSSH "limits" extension, allowing the + client to query server limits such as the maximum supported read + and write sizes. The client will automatically default to the reported + maximum size on servers that support this extension. + +* Added more ways to specify environment variables via the `env` option. + Sequences of either 'key=value' strings or (key, value) tuples are now + supported, in addition to a dict. + +* Added support for getting/setting environment variables as byte strings + on platforms which support it. Previously, only Unicode strings were + accepted and they were always encoded on the wire using UTF-8. + +* Added support for non-TCP sockets (such as a socketpair) as the `sock` + parameter in connect calls. Thanks go to Christian Wendt for reporting + this problem and proposing a fix. + +* Changed compression to be disabled by default to avoid it becoming a + performance bottleneck on high-bandwidth connections. This now also + matches the OpenSSH default. + +* Improved speed of parallelized SFTP reads when read-ahead goes beyond + the end of the file. Thanks go to Maximilian Knespel for reporting + this issue and providing performance measurements on the code before + and after the change. + +* Improved cancellation handling during SCP transfers. + +* Improved support for selecting the currently available security key + when the application lists multiple keys to try. Thanks go to GitHub + user zanda8893 for reporting the issue and helping to work out the + details of the problem. + +* Improved handling of reverse DNS failures in host-based authentication. + Thanks go to GitHub user xBiggs for suggesting this change. + +* Improved debug logging of byte strings with non-printable characters. + +* Switched to using an executor on GSSAPI calls to avoid blocking the + event loop. + +* Fixed handling of "UserKnownHostsFile none" in config files. This + previously caused it to use the default known hosts, rather than + disabling known host checking. + +* Fixed a runtime warning about not awaiting a coroutine in unit tests. + +* Fixed a unit test failure on Windows when calling abort on a transport. + +* Fixed a problem where a "MAC verification failed" error was sometimes + sent on connection close. + +* Fixed SSHClientProcess code to not raise a runtime exception when + waiting more than once for a process to finish. Thanks go to GitHub + user starflows for reporting this issue. + +* Handled an error when attempting to import older verions of pyOpenSSL. + Thanks go to Maximilian Knespel for reporting this issue and testing + the fix. + +* Updated simple_server example code to switch from crypt to bcrypt, + since crypt has been removed in Python 3.13. Thanks go to Colin + Watson for providing this update. + +Release 2.17.0 (2 Sep 2024) +--------------------------- + +* Added support for specifying a per-connection credential store for GSSAPI + authentication. Thanks go to GitHub user zarganum for suggesting this + feature and proposing a detailed design. + +* Fixed a regression introduced in AsyncSSH 2.15.0 which could cause + connections to be closed with an uncaught exception when a session + on the connection was closed. Thanks go to Wilson Conley for being + the first to help reproduce this issue, and others who also helped + to confirm the fix. + +* Added a workaround where getaddrinfo() on some systems may return duplicate + entries, causing bind() to fail when opening a listener. Thanks go to + Colin Watson for reporting this issue and suggesting a fix. + +* Relaxed padding length check on OpenSSH private keys to provide better + compatibility with keys generated by PuTTYgen. + +* Improved documentation on SSHClient and SSHServer classes to explain + when they are created and their relationship to the SSHClientConnection + and SSHServerConnection classes. + +* Updated examples to use Python 3.7 and made some minor improvements. + +Release 2.16.0 (17 Aug 2024) +---------------------------- + +* Added client and server support for the OpenSSH "hostkeys" extension. + When using known_hosts, clients can provide a handler which will be + called with the changes between the keys currently trusted in the + client's known hosts and those available on the server. On the server + side, an application can choose whether or not to enable the sending + of this host key information. Thanks go to Matthijs Kooijman for + getting me to take another look at how this might be supported. + +* Related to the above, AsyncSSH now allows the configuration of multiple + server host keys of the same type when the send_server_host_keys option + is enabled. Only the first key of each type will be used in the SSH + handshake, but the others can appear in the list of supported host keys + for clients to begin trusting, allowing for smoother key rotation. + +* Fixed logging and typing issues in SFTP high-level copy functions. + A mix of bytes, str, and PurePath entries are now supported in places + where a list of file paths is allowed, and the type signatures have + been updated to reflect that the functions accept either a single + path or a list of paths. Thanks go to GitHub user eyalgolan1337 for + reporting these issues. + +* Improved typing on SFTP listdir() function. Thanks go to Tim Stumbaugh + for contributing this change. + +* Reworked the config file parser to improve on a previous fix related + to handling key/value pairs with an equals delimiter. + +* Improved handling of ciphers deprecated in cryptography 43.0.0. + Thanks go to Guillaume Mulocher for reporting this issue. + +* Improved support for use of Windows pathnames in ProxyCommand. + Thanks go to GitHub user chipolux for reporting this issue and + investigating the existing OpenSSH parsing behavior. + +Release 2.15.0 (3 Jul 2024) +--------------------------- + +* Added experimental support for tunneling of TUN/TAP network interfaces + on Linux and macOS, allowing for either automatic packet forwarding or + explicit reading and writing of packets sent through the tunnel by the + application. Both callback and stream APIs are available. + +* Added support for forwarding terminal size and terminal size changes + when stdin on an SSHServerProcess is redirected to a local TTY. + +* Added support for multiple tunnel/ProxyJump hosts. Thanks go to Adam + Martin for suggesting this enhancement and proposing a solution. + +* Added support for OpenSSH lsetstat SFTP extension to set attributes + on symbolic links on platforms which support that and use it to + improve symlink handling in the SFTP get, put, and copy methods. + In addition, a follow_symlinks option has been added on various + SFTPClient methods which get and set these attributes. Thanks go to + GitHub user eyalgolan1337 for reporting this issue. + +* Added support for password and passphrase arguments to be a callable + or awaitable, called when performing authentication or loading + encrypted private keys. Thanks go to GitHub user goblin for + suggesting this enhancement. + +* Added support for proper flow control when using AsyncFileWriter or + StreamWriter classes to do SSH process redirection. Thanks go to Benjy + Wiener for reporting this issue and providing feedback on the fix. + +* Added is_closed() method SSHClientConnection/SSHServerConnection to + return whether the associated network connection is closed or not. + +* Added support for setting and matching tags in OpenSSH config files. + +* Added an example of using "await" in addition to "async with" when + opening a new SSHClientConnection. Thanks go to Michael Davis for + suggesting this added documentation. + +* Improved handling CancelledError in SCP, avoiding an issue where + AsyncSSH could sometimes get stuck waiting for the channel to close. + Thanks go to Max Orlov for reporting the problem and providing code + to reproduce it. + +* Fixed a regression from 2.14.1 related to rekeying an SSH connection + when there's acitivty on the connection in the middle of rekeying. + Thanks go to GitHub user eyalgolan1337 for helping to narrow down + this problem and test the fix. + +* Fixed a problem with process redirection when a close is received + without a preceding EOF. Thanks go to GitHub user xuoguoto who helped + to provide sample scripts and ran tests to help track this down. + +* Fixed the processing of paths in SFTP client symlink requests. Thanks + go to André Glüpker for reporting the problem and providing test code + to demonstrate it. + +* Fixed an OpenSSH config file parsing issue. Thanks go to Siddh Raman + Pant for reporting this issue. + +* Worked around a bug in a user auth banner generated by the cryptlib + library. Thanks go to GitHub user mmayomoar for reporting this issue + and suggesting a fix. + +Release 2.14.2 (18 Dec 2023) +---------------------------- + +* Implemented "strict kex" support and other countermeasures to + protect against the Terrapin Attack described in `CVE-2023-48795 + `_. Thanks once + again go to Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for + identifying and reporting this vulnerability and providing detailed + analysis and suggestions about proposed fixes. + +* Fixed config parser to properly an optional equals delimiter in all + config arguments. Thanks go to Fawaz Orabi for reporting this issue. + +* Fixed TCP send error handling to avoid race condition when receiving + incoming disconnect message. + +* Improved type signature in SSHConnection async context manager. Thanks + go to Pieter-Jan Briers for providing this. + +Release 2.14.1 (8 Nov 2023) +--------------------------- + +* Hardened AsyncSSH state machine against potential message + injection attacks, described in more detail in `CVE-2023-46445 + `_ and `CVE-2023-46446 + `_. Thanks go to + Fabian Bäumer, Marcus Brinkmann, and Jörg Schwenk for identifying + and reporting these vulnerabilities and providing detailed analysis + and suggestions about the proposed fixes. + +* Added support for passing in a regex in readuntil in SSHReader, + contributed by Oded Engel. + +* Added support for get_addresses() and get_port() methods on + SSHAcceptor. Thanks go to Allison Karlitskaya for suggesting + this feature. + +* Fixed an issue with AsyncFileWriter potentially writing data + out of order. Thanks go to Chan Chun Wai for reporting this + issue and providing code to reproduce it. + +* Updated testing to include Python 3.12. + +* Updated readthedocs integration to use YAML config file. + +Release 2.14.0 (30 Sep 2023) +---------------------------- + +* Added support for a new accept_handler argument when setting up + local port forwarding, allowing the client host and port to be + validated and/or logged for each new forwarded connection. An + accept handler can also be returned from the server_requested + function to provide this functionality when acting as a server. + Thanks go to GitHub user zgxkbtl for suggesting this feature. + +* Added an option to disable expensive RSA private key checks when + using OpenSSL 3.x. Functions that read private keys have been + modified to include a new unsafe_skip_rsa_key_validation argument + which can be used to avoid these additional checks, if you are + loading keys from a trusted source. + +* Added host information into AsyncSSH exceptions when host key + validation fails, and a few other improvements related to X.509 + certificate validation errors. Thanks go to Peter Moore for + suggesting this and providing an example. + +* Fixed a regression which prevented keys loaded into an SSH agent + with a certificate from working correctly beginning in AsyncSSH + after version 2.5.0. Thanks go to GitHub user htol for reporting + this issue and suggesting the commit which caused the problem. + +* Fixed an issue which was triggering an internal exception when + shutting down server sessions with the line editor enabled which + could cause some output to be lost on exit, especially when running + on Windows. Thanks go to GitHub user jerrbe for reporting this issue. + +* Fixed an issue in a unit test seen in Python 3.12 beta. Thanks go + to Georg Sauthoff for providing this fix. + +* Fixed a documentation error in SSHClientConnectionOptions and + SSHServerConnectionOptions. Thanks go to GitHub user bowenerchen + for reporting this issue. + +Release 2.13.2 (21 Jun 2023) +---------------------------- + +* Fixed an issue with host-based authentication when using proxy_command, + allowing it to be used if the caller explicitly specifies client_host. + Thanks go to GitHub user yuqingm7 for reporting this issue. + +* Improved handling of signature algorithms for OpenSSH certificates + so that RSA SHA-2 signatures will work with both older and newer + versions of OpenSSH. + +* Worked around an issue with some Cisco SSH implementations generating + invalid "ignore" packets. Thanks go to Jost Luebbe for reporting and + helping to debug this issue. + +* Fixed unit tests to avoid errors when cryptography's version of + OpenSSL disables support for SHA-1 signatures. + +* Fixed unit tests to avoid errors when the filesystem enforces that + filenames be valid UTF-8 strings. Thanks go to Robert Schütz and + Martin Weinelt for reporting this issue. + +* Added documentation about which config options apply when passing + a string as a tunnel argument. + +Release 2.13.1 (18 Feb 2023) +---------------------------- + +* Updated type definitions for mypy 1.0.0, removing a dependency on + implicit Optional types, and working around an issue that could + trigger a mypy internal error. + +* Updated unit tests to avoid calculation of SHA-1 signatures, which + are no longer allowed in cryptography 39.0.0. + +Release 2.13.0 (27 Dec 2022) +---------------------------- + +* Updated testing and coverage to drop Python 3.6 and add Python 3.11. + Thanks go to GitHub user hexchain for maintaining the GitHub workflows + supporting this! + +* Added new "recv_eof" option to not pass an EOF from a channel to a + redirected target, allowing output from multiple SSH sessions to be + sent and mixed with other direct output to that target. This is meant + to be similar to the existing "send_eof" option which controls whether + EOF on a redirect source is passed through to the SSH channel. Thanks + go to Stuart Reynolds for inspiring this idea. + +* Added new methods to make it easy to perform forwarding between TCP + ports and UNIX domain sockets. Thanks go to Alex Rogozhnikov for + suggesting this use case. + +* Added a workaround for a problem seen on a Huawei SFTP server where + it sends an invalid combination of file attribute flags. In cases where + the flags are otherwise valid and the right amount of attribute data is + available, AsyncSSH will ignore the invalid flags and proceed. + +* Fixed an issue with copying files to SFTP servers that don't support + random access I/O. The potential to trigger this failyre goes back + several releases, but a change in AsyncSSH 2.12 made out-of-order writes + much more likely. This fix returns AsyncSSH to its previous behavior + where out-of-order writes are unlikely even when taking advantage of + parallel reads. Thanks go to Patrik Lindgren and Stefan Walkner for + reporting this issue and helping to identify the source of the problem. + +* Fixed an issue when requesting remote port forwarding on a dynamically + allocated port. Thanks go to Daniel Shimon for reporting this and + proposing a fix. + +* Fixed an issue where readexactly could block indefinitely when a signal + is delivered in the stream before the requested number of bytes are + available. Thanks go to Artem Bezborodko for reporting this and + providing a fix. + +* Fixed an interoperability issue with OpenSSH when using SSH certificates + with RSA keys with a SHA-2 signature. Thanks go to Łukasz Siudut for + reporting this. + +* Fixed an issue with handling "None" in ProxyCommand, GlobalKnownHostsFile, + and UserKnownHostsFile config file options. Thanks go to GitHub user + dtrifiro for reporting this issue and suggesting a fix. + +Release 2.12.0 (10 Aug 2022) +---------------------------- + +* Added top-level functions run_client() and run_server() which allow + you to begin running an SSH client or server on an already-connected + socket. This capability is also available via a new "sock" argument + in the existing connect(), connect_reverse(), get_server_host_key(), + and get_server_auth_methods() functions. + +* Added "sock" argument to listen() and listen_reverse() functions + which takes an already-bound listening socket instead of a host + and port to bind a new socket to. + +* Added support for forwarding break, signal, and terminal size updates + when redirection of stdin is set up between two SSHProcess instances. + +* Added support for sntrup761x25519-sha512@openssh.com post-quantum + key exchange algorithm. For this to be available, the Open Quantum + Safe (liboqs) dynamic library must be installed. + +* Added "sig_alg" argument to set a signature algorithm when creating + OpenSSH certificates, allowing a choice between ssh-rsa, rsa-sha2-256, + and rsa-sha2-512 for certificates signed by RSA keys. + +* Added new read_parallel() method in SFTPClientFile which allows + parallel reads to be performed from a remote file, delivering + incremental results as these reads complete. Previously, large + reads would automatically be parallelized, but a result was only + returned after all reads completed. + +* Added definition of __all__ for public symbols in AsyncSSH to make + pyright autocompletion work better. Thanks go to Nicolas Riebesel + for providing this change. + +* Updated SFTP and SCP glob and copy functions to use scandir() instead + of listdir() to improve efficiency. + +* Updated default for "ignore_encrypted" client connection option to + ignore encrypted keys specified in an OpenSSH config file when no + passphrase is provided, similar to what was previously done for + keys with default names. + +* Fixed an issue when using an SSH agent with RSA keys and an X.509 + certificate while requesting SHA-2 signatures. + +* Fixed an issue with use of expanduser() in unit tests on newer versions + of Python. Thanks go to Georg Sauthoff for providing an initial version + of this fix. + +* Fixed an issue with fallback to a Pageant agent not working properly + on Windows when no agent_path or SSH_AUTH_SOCK was set. + +* Fixed improper escaping in readuntil(), causing certain punctuation in + separator to not match properly. Thanks go to Github user MazokuMaxy + for reporting this issue. + +* Fixed the connection close handler to properly mark channels as fully + closed when the peer unexpected closes the connection, allowing + exceptions to fire if an application continues to try and use + the channel. Thanks go to Taha Jahangir for reporting this issue and + suggesting a possible fix. + +* Eliminated unit testing against OpenSSH for tests involving DSA and + RSA keys using SHA-1 signatures, since this support is being dropped + in some distributions of OpenSSH. These tests are still performed, but + using only AsyncSSH code. Thanks go to Ken Dreyer and Georg Sauthoff + for reporting this issue and helping me to reproduce it. + +Release 2.11.0 (4 Jun 2022) +--------------------------- + +* Made a number of improvements in SFTP glob support, with thanks to + Github user LuckyDams for all the help working out these changes! + + * Added a new glob_sftpname() method which returns glob matches + together with attribute information, avoiding the need for a + caller to make separate calls to stat() on the returned results. + * Switched from listdir() to scandir() to reduce the number of + stat() operations required while finding matches. + * Added code to remove duplicates when glob() is called with + multiple patterns that match the same path. + * Added a cache of directory listing and stat results to improve + performance when matching patterns with overlapping paths. + * Fixed an "index out of range" bug in recursive glob matching + and aligned it better with results reeturned by UNIX shells. + * Changed matching to ignore inaccessible or non-existent paths + in a glob pattern, to allow accessible paths to be fully + explored before returning an error. The error handler will now + be called only if a pattern results in no matches, or if a more + serious error occurs while scanning. + +* Changed SFTP makedirs() method to work better in cases where parts of + requested path already exist but don't allow read access. As long as + the entire path can be created, makedirs() will succeed, even if some + directories on the path don't allow their contents to be read. Thanks + go to Peter Rowlands for providing this fix. + +* Replaced custom Diffie Hellman implementation in AsyncSSH with the + one in the cryptography package, resulting in an over 10x speedup. + Thanks go to Github user iwanb for suggesting this change. + +* Fixed AsyncSSH to re-acquire GSS credentials when performing key + renegotiation to avoid expired credentials on long-lived connections. + Thanks go to Github user PromyLOPh for pointing out this issue and + suggesting a fix. + +* Fixed GSS MIC to work properly with GSS key exchange when AsyncSSH + is running as a server. This was previously fixed on the client side, + but a similar fix for the server was missed. + +* Changed connection timeout unit tests to work better in environments + where a firewall is present. Thanks go to Stefano Rivera for + reporting this issue. + +* Improved unit tests of Windows SSPI GSSAPI module. + +* Improved speed of unit tests by reducing the number of key generation + calls. RSA key generation in particular has gotten much more expensive + in OpenSSL 3. + Release 2.10.1 (16 Apr 2022) ---------------------------- +* Added a workaround for a bug in dropbear which can improperly reject + full-sized data packets when compression is enabled. Thanks go to + Matti Niemenmaa for reporting this issue and helping to reproduce it. + * Added support for "Match Exec" in config files and updated AsyncSSH API calls to do config parsing in an executor to avoid blocking the event loop if a "Match Exec" command doesn't return immediately. @@ -20,7 +629,7 @@ Release 2.10.1 (16 Apr 2022) trigger a debug message rather than an error. Thanks go to Caleb Ho for reporting this issue! -* Update minimum required version of cryprography package to 3.1, to +* Updated minimum required version of cryprography package to 3.1, to allow calls to it to be made without passing in a "backend" argument. This was missed back in the 2.9 release. Thanks go to Github users sebby97 and JavaScriptDude for reporting this issue! @@ -161,7 +770,7 @@ Release 2.7.1 (6 Sep 2021) * Fixed a couple of issues related to sending SSH_EXT_INFO messages. * Fixed an issue with using SSHAcceptor as an async context manager. - Thanks go to Paulo Costa for reporing this. + Thanks go to Paulo Costa for reporting this. * Fixed an issue where a tunnel wasn't always cleaned up properly when creating a remote listener. @@ -281,7 +890,7 @@ Release 2.6.0 (1 May 2021) 0.8.1. * Fixed problem with setting config options with percent substitutions - to 'none'. Percent subsitution should not be performed in this case. + to 'none'. Percent substitution should not be performed in this case. Thanks go to Yuqing Miao for finding and reporting this issue! * Fixed return type of filenames in SFTPClient scandir() and readlink() @@ -774,7 +1383,7 @@ Release 1.16.1 (30 Mar 2019) * Added channel, connection, and env properties to SFTPServer instances, so connection and channel information can be used to influence the SFTP server's behavior. Previously, connection information was made - avaiable through the constructor, but channel and environment + available through the constructor, but channel and environment information was not. Now, all of these are available as properties on the SFTPServer instance without the need to explicitly store anything in a custom constructor. @@ -1137,7 +1746,7 @@ Release 1.12.0 (5 Feb 2018) * Updated key and certificate comment handling to be less sensitive to the encoding of non-ASCII characters. The get_comment() and set_comment() - functions now take an optional encoding paramter, defaulting to UTF-8 + functions now take an optional encoding parameter, defaulting to UTF-8 but allowing for others encodings. There's also a get_comment_bytes() function to get the comment data as bytes without performing Unicode decoding. @@ -1386,7 +1995,7 @@ Release 1.7.0 (7 Oct 2016) these signature algorithms. * Added new load_keypairs and load_public_keys API functions which - support expicitly loading keys using the same syntax that was + support explicitly loading keys using the same syntax that was previously available for specifying client_keys, authorized_client_keys, and server_host_keys arguments when creating SSH clients and servers. diff --git a/docs/conf.py b/docs/conf.py index f8ce409..f36e3e8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,7 +44,7 @@ # General information about the project. project = 'AsyncSSH' -copyright = '2013-2017, ' + __author__ +copyright = '2013-2023, ' + __author__ # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the @@ -103,7 +103,7 @@ # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - "sidebarwidth": 305, + "sidebarwidth": 450, "stickysidebar": "true" } @@ -129,7 +129,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +#html_static_path = ['_static'] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. diff --git a/docs/index.rst b/docs/index.rst index c931303..b05d581 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,26 @@ key doesn't match. :literal: :start-line: 22 -This example only uses the output on stdout, but output on stderr is also +This example shows using the :class:`SSHClientConnection` returned by +:func:`connect()` as a context manager, so that the connection is +automatically closed when the end of the code block which opened it is +reached. However, if you need the connection object to live longer, you +can use "await" instead of "async with": + + .. code:: + + conn = await asyncssh.connect('localhost') + +In this case, the application will need to close the connection explicitly +when done with it, and it is best to also wait for the close to complete. +This can be done with the following code from inside an async function: + + .. code:: + + conn.close() + await conn.wait_closed() + +Only stdout is referenced this example, but output on stderr is also collected as another attribute in the returned :class:`SSHCompletedProcess` object. @@ -38,7 +57,7 @@ write calls operate on strings by default. If you want to send and receive binary data, you can set the encoding to `None` when the session is opened to make read and write operate on bytes instead. Alternate encodings can also be selected to change how strings are -convered to and from bytes. +converted to and from bytes. To check against a different set of server host keys, they can be provided in the known_hosts argument when the connection is opened: @@ -435,7 +454,7 @@ write calls operate on strings by default. If you want to send and receive binary data, you can set the encoding to `None` when the session is opened to make read and write operate on bytes instead. Alternate encodings can also be selected to change how strings are -convered to and from bytes. +converted to and from bytes. .. include:: ../examples/simple_server.py :literal: @@ -512,7 +531,7 @@ provide character echo and line editing. To better support interactive applications like the one above, AsyncSSH defaults to providing basic line editing for server sessions which request a pseudo-terminal. -When thise line editor is enabled, it defaults to delivering input to +When this line editor is enabled, it defaults to delivering input to the application a line at a time. Applications can switch between line and character at a time input using the :meth:`set_line_mode() ` method. Also, when in line diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000..3351a64 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,2 @@ +cryptography >= 39.0 +typing_extensions >= 3.6 diff --git a/examples/callback_client.py b/examples/callback_client.py index 713e78d..c9069cf 100755 --- a/examples/callback_client.py +++ b/examples/callback_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -33,7 +33,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: class MySSHClient(asyncssh.SSHClient): def connection_made(self, conn: asyncssh.SSHClientConnection) -> None: - print('Connection made to %s.' % conn.get_extra_info('peername')[0]) + print(f'Connection made to {conn.get_extra_info('peername')[0]}.') def auth_completed(self) -> None: print('Authentication successful.') @@ -46,6 +46,6 @@ async def run_client() -> None: await chan.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/callback_client2.py b/examples/callback_client2.py index e2db556..7a6978c 100755 --- a/examples/callback_client2.py +++ b/examples/callback_client2.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -37,6 +37,6 @@ async def run_client() -> None: await chan.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/callback_client3.py b/examples/callback_client3.py index fca0c33..fcf4f3f 100755 --- a/examples/callback_client3.py +++ b/examples/callback_client3.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -40,6 +40,6 @@ async def run_client() -> None: await chan.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/callback_math_server.py b/examples/callback_math_server.py index 8d3994d..5599635 100755 --- a/examples/callback_math_server.py +++ b/examples/callback_math_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -52,18 +52,21 @@ def data_received(self, data: str, datatype: asyncssh.DataType) -> None: if line: self._total += int(line) except ValueError: - self._chan.write_stderr('Invalid number: %s\n' % line) + self._chan.write_stderr(f'Invalid number: {line}\n') self._input = lines[-1] def eof_received(self) -> bool: - self._chan.write('Total = %s\n' % self._total) + self._chan.write(f'Total = {self._total}\n') self._chan.exit(0) return False def break_received(self, msec: int) -> bool: return self.eof_received() + def soft_eof_received(self) -> None: + self.eof_received() + class MySSHServer(asyncssh.SSHServer): def session_requested(self) -> asyncssh.SSHServerSession: return MySSHServerSession() @@ -73,7 +76,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/chat_server.py b/examples/chat_server.py index 888a959..275d3ec 100755 --- a/examples/chat_server.py +++ b/examples/chat_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -57,18 +57,18 @@ async def run(self) -> None: self.write('Enter your name: ') name = (await self.readline()).rstrip('\n') - self.write('\n%d other users are connected.\n\n' % len(self._clients)) + self.write(f'\n{len(self._clients)} other users are connected.\n\n') self._clients.append(self) - self.broadcast('*** %s has entered chat ***\n' % name) + self.broadcast(f'*** {name} has entered chat ***\n') try: async for line in self._process.stdin: - self.broadcast('%s: %s' % (name, line)) + self.broadcast(f'{name}: {line}') except asyncssh.BreakReceived: pass - self.broadcast('*** %s has left chat ***\n' % name) + self.broadcast(f'*** {name} has left chat ***\n') self._clients.remove(self) async def start_server() -> None: @@ -76,7 +76,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=ChatClient.handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/check_exit_status.py b/examples/check_exit_status.py index c421e4d..7a96e69 100755 --- a/examples/check_exit_status.py +++ b/examples/check_exit_status.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,6 +34,6 @@ async def run_client() -> None: file=sys.stderr) try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/chroot_sftp_server.py b/examples/chroot_sftp_server.py index 28f9b4b..8fd5509 100755 --- a/examples/chroot_sftp_server.py +++ b/examples/chroot_sftp_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -40,7 +40,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', sftp_factory=MySFTPServer) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/direct_client.py b/examples/direct_client.py index 8d9fd43..5b5a383 100755 --- a/examples/direct_client.py +++ b/examples/direct_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -44,6 +44,6 @@ async def run_client() -> None: await chan.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/direct_server.py b/examples/direct_server.py index abd7f6f..26d21fa 100755 --- a/examples/direct_server.py +++ b/examples/direct_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -52,7 +52,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/editor.py b/examples/editor.py index d57caed..baa3971 100755 --- a/examples/editor.py +++ b/examples/editor.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -33,8 +33,8 @@ async def handle_client(process: asyncssh.SSHServerProcess): channel = cast(asyncssh.SSHLineEditorChannel, process.channel) - process.stdout.write('Welcome to my SSH server, %s!\n\n' % - process.get_extra_info('username')) + username = process.get_extra_info('username') + process.stdout.write(f'Welcome to my SSH server, {username}!\n\n') channel.set_echo(False) process.stdout.write('Tell me a secret: ') @@ -53,7 +53,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/gather_results.py b/examples/gather_results.py index a0c3d78..958dae4 100755 --- a/examples/gather_results.py +++ b/examples/gather_results.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -35,14 +35,14 @@ async def run_multiple_clients() -> None: for i, result in enumerate(results, 1): if isinstance(result, Exception): - print('Task %d failed: %s' % (i, str(result))) + print(f'Task {i} failed: {result}') elif result.exit_status != 0: - print('Task %d exited with status %s:' % (i, result.exit_status)) + print(f'Task {i} exited with status {result.exit_status}:') print(result.stderr, end='') else: - print('Task %d succeeded:' % i) + print(f'Task {i} succeeded:') print(result.stdout, end='') print(75*'-') -asyncio.get_event_loop().run_until_complete(run_multiple_clients()) +asyncio.run(run_multiple_clients()) diff --git a/examples/listening_client.py b/examples/listening_client.py index 2a85ae6..68e9523 100755 --- a/examples/listening_client.py +++ b/examples/listening_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,7 +31,7 @@ def data_received(self, data: bytes, datatype: asyncssh.DataType): def connection_requested(orig_host: str, orig_port: int) -> asyncssh.SSHTCPSession: - print('Connection received from %s, port %s' % (orig_host, orig_port)) + print(f'Connection received from {orig_host}, port {orig_port}') return MySSHTCPSession() async def run_client() -> None: @@ -45,6 +45,6 @@ async def run_client() -> None: print('Listener couldn\'t be opened.', file=sys.stderr) try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/local_forwarding_client.py b/examples/local_forwarding_client.py index 0153d69..1670b77 100755 --- a/examples/local_forwarding_client.py +++ b/examples/local_forwarding_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -28,6 +28,6 @@ async def run_client() -> None: await listener.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/local_forwarding_client2.py b/examples/local_forwarding_client2.py index ca7996f..b28e7c0 100755 --- a/examples/local_forwarding_client2.py +++ b/examples/local_forwarding_client2.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -25,10 +25,10 @@ async def run_client() -> None: async with asyncssh.connect('localhost') as conn: listener = await conn.forward_local_port('', 0, 'www.google.com', 80) - print('Listening on port %s...' % listener.get_port()) + print(f'Listening on port {listener.get_port()}...') await listener.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/local_forwarding_server.py b/examples/local_forwarding_server.py index 44cca5e..f4e4983 100755 --- a/examples/local_forwarding_server.py +++ b/examples/local_forwarding_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -44,7 +44,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/math_client.py b/examples/math_client.py index 64ee085..eeee7bb 100755 --- a/examples/math_client.py +++ b/examples/math_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,6 +31,6 @@ async def run_client() -> None: print(op, '=', result, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/math_server.py b/examples/math_server.py index 301fadc..9f118cf 100755 --- a/examples/math_server.py +++ b/examples/math_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -41,11 +41,11 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: try: total += int(line) except ValueError: - process.stderr.write('Invalid number: %s\n' % line) + process.stderr.write(f'Invalid number: {line}\n') except asyncssh.BreakReceived: pass - process.stdout.write('Total = %s\n' % total) + process.stdout.write(f'Total = {total}\n') process.exit(0) async def start_server() -> None: @@ -53,7 +53,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/redirect_input.py b/examples/redirect_input.py index 0316d81..ba8f5be 100755 --- a/examples/redirect_input.py +++ b/examples/redirect_input.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -27,6 +27,6 @@ async def run_client() -> None: await conn.run('tail -r', input='1\n2\n3\n', stdout='/tmp/stdout') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/redirect_local_pipe.py b/examples/redirect_local_pipe.py index 25d619b..7e90984 100755 --- a/examples/redirect_local_pipe.py +++ b/examples/redirect_local_pipe.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -30,6 +30,6 @@ async def run_client() -> None: print(remote_result.stdout, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/redirect_remote_pipe.py b/examples/redirect_remote_pipe.py index 4c6f0e8..1a8d2dd 100755 --- a/examples/redirect_remote_pipe.py +++ b/examples/redirect_remote_pipe.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -29,6 +29,6 @@ async def run_client() -> None: print(proc2_result.stdout, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/redirect_server.py b/examples/redirect_server.py index ae98af6..5ee6a36 100755 --- a/examples/redirect_server.py +++ b/examples/redirect_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -43,7 +43,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/remote_forwarding_client.py b/examples/remote_forwarding_client.py index da6b33a..0ff1f16 100755 --- a/examples/remote_forwarding_client.py +++ b/examples/remote_forwarding_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -28,6 +28,6 @@ async def run_client() -> None: await listener.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/remote_forwarding_client2.py b/examples/remote_forwarding_client2.py index 44b0bd8..5733e05 100755 --- a/examples/remote_forwarding_client2.py +++ b/examples/remote_forwarding_client2.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,7 +24,7 @@ from functools import partial from typing import Awaitable -def connection_requested(conn: asyncssh.SSHServerConnection, orig_host: str, +def connection_requested(conn: asyncssh.SSHClientConnection, orig_host: str, orig_port: int) -> Awaitable[asyncssh.SSHForwarder]: if orig_host in ('127.0.0.1', '::1'): return conn.forward_connection('localhost', 80) @@ -40,6 +40,6 @@ async def run_client() -> None: await listener.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/remote_forwarding_server.py b/examples/remote_forwarding_server.py index 75f3479..5b35bda 100755 --- a/examples/remote_forwarding_server.py +++ b/examples/remote_forwarding_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -38,7 +38,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/reverse_client.py b/examples/reverse_client.py index 905c42b..0a5375b 100755 --- a/examples/reverse_client.py +++ b/examples/reverse_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.8 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -58,6 +58,6 @@ async def run_reverse_client() -> None: await conn.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_reverse_client()) + asyncio.run(run_reverse_client()) except (OSError, asyncssh.Error) as exc: sys.exit('Reverse SSH connection failed: ' + str(exc)) diff --git a/examples/reverse_server.py b/examples/reverse_server.py index 90c67e7..cea4c46 100755 --- a/examples/reverse_server.py +++ b/examples/reverse_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.8 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -57,7 +57,7 @@ async def start_reverse_server() -> None: known_hosts='trusted_client_host_keys', acceptor=run_commands) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_reverse_server()) diff --git a/examples/scp_client.py b/examples/scp_client.py index 11659f1..5c0581e 100755 --- a/examples/scp_client.py +++ b/examples/scp_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2017-2021 by Ron Frederick and others. +# Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -26,6 +26,6 @@ async def run_client() -> None: await asyncssh.scp('localhost:example.txt', '.') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) diff --git a/examples/set_environment.py b/examples/set_environment.py index dc1048c..b9827be 100755 --- a/examples/set_environment.py +++ b/examples/set_environment.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -29,6 +29,6 @@ async def run_client() -> None: print(result.stdout, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/set_terminal.py b/examples/set_terminal.py index 18a057d..ec541bb 100755 --- a/examples/set_terminal.py +++ b/examples/set_terminal.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -30,6 +30,6 @@ async def run_client() -> None: print(result.stdout, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/sftp_client.py b/examples/sftp_client.py index 49d804b..470da40 100755 --- a/examples/sftp_client.py +++ b/examples/sftp_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -28,6 +28,6 @@ async def run_client() -> None: await sftp.get('example.txt') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SFTP operation failed: ' + str(exc)) diff --git a/examples/show_environment.py b/examples/show_environment.py index f3c50b5..46e9b04 100755 --- a/examples/show_environment.py +++ b/examples/show_environment.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,7 +34,7 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: keywidth = max(map(len, process.env.keys()))+1 process.stdout.write('Environment:\n') for key, value in process.env.items(): - process.stdout.write(' %-*s %s\n' % (keywidth, key+':', value)) + process.stdout.write(f' {key+":":{keywidth}} {value}\n') process.exit(0) else: process.stderr.write('No environment sent.\n') @@ -45,7 +45,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/show_terminal.py b/examples/show_terminal.py index 1cbfc02..9d6ade3 100755 --- a/examples/show_terminal.py +++ b/examples/show_terminal.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -32,21 +32,20 @@ async def handle_client(process: asyncssh.SSHServerProcess) -> None: width, height, pixwidth, pixheight = process.term_size - process.stdout.write('Terminal type: %s, size: %sx%s' % - (process.term_type, width, height)) + process.stdout.write(f'Terminal type: {process.term_type}, ' + f'size: {width}x{height}') if pixwidth and pixheight: - process.stdout.write(' (%sx%s pixels)' % (pixwidth, pixheight)) + process.stdout.write(f' ({pixwidth}x{pixheight} pixels)') process.stdout.write('\nTry resizing your window!\n') while not process.stdin.at_eof(): try: await process.stdin.read() except asyncssh.TerminalSizeChanged as exc: - process.stdout.write('New window size: %sx%s' % - (exc.width, exc.height)) + process.stdout.write(f'New window size: {exc.width}x{exc.height}') if exc.pixwidth and exc.pixheight: - process.stdout.write(' (%sx%s pixels)' % - (exc.pixwidth, exc.pixheight)) + process.stdout.write(f' ({exc.pixwidth}' + f'x{exc.pixheight} pixels)') process.stdout.write('\n') async def start_server() -> None: @@ -54,7 +53,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/simple_cert_server.py b/examples/simple_cert_server.py index a13a21b..78e5a6f 100755 --- a/examples/simple_cert_server.py +++ b/examples/simple_cert_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -30,8 +30,8 @@ import asyncio, asyncssh, sys def handle_client(process: asyncssh.SSHServerProcess) -> None: - process.stdout.write('Welcome to my SSH server, %s!\n' % - process.get_extra_info('username')) + username = process.get_extra_info('username') + process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) async def start_server() -> None: @@ -39,7 +39,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/simple_client.py b/examples/simple_client.py index 0bb5365..71a8d06 100755 --- a/examples/simple_client.py +++ b/examples/simple_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,10 +24,16 @@ async def run_client() -> None: async with asyncssh.connect('localhost') as conn: - result = await conn.run('ls abc', check=True) - print(result.stdout, end='') + try: + result = await conn.run('ls abc', check=True) + except asyncssh.ProcessError as exc: + print(exc.stderr, end='') + print(f'Process exited with status {exc.exit_status}', + file=sys.stderr) + else: + print(result.stdout, end='') try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/simple_keyed_server.py b/examples/simple_keyed_server.py index 22b4c3b..57d81bc 100755 --- a/examples/simple_keyed_server.py +++ b/examples/simple_keyed_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -31,8 +31,8 @@ import asyncio, asyncssh, sys def handle_client(process: asyncssh.SSHServerProcess) -> None: - process.stdout.write('Welcome to my SSH server, %s!\n' % - process.get_extra_info('username')) + username = process.get_extra_info('username') + process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) class MySSHServer(asyncssh.SSHServer): @@ -41,8 +41,8 @@ def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: def begin_auth(self, username: str) -> bool: try: - self._conn.set_authorized_keys('authorized_keys/%s' % username) - except IOError: + self._conn.set_authorized_keys(f'authorized_keys/{username}') + except OSError: pass return True @@ -52,7 +52,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/simple_scp_server.py b/examples/simple_scp_server.py index 402cfb0..117c8d5 100755 --- a/examples/simple_scp_server.py +++ b/examples/simple_scp_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,7 +34,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', sftp_factory=True, allow_scp=True) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/simple_server.py b/examples/simple_server.py index d5da02a..9bbdc7c 100755 --- a/examples/simple_server.py +++ b/examples/simple_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,22 +24,22 @@ # private key in it to use as a server host key. An SSH host certificate # can optionally be provided in the file ``ssh_host_key-cert.pub``. -import asyncio, asyncssh, crypt, sys +import asyncio, asyncssh, bcrypt, sys from typing import Optional -passwords = {'guest': '', # guest account with no password - 'user123': 'qV2iEadIGV2rw' # password of 'secretpw' +passwords = {'guest': b'', # guest account with no password + 'user123': bcrypt.hashpw(b'secretpw', bcrypt.gensalt()), } def handle_client(process: asyncssh.SSHServerProcess) -> None: - process.stdout.write('Welcome to my SSH server, %s!\n' % - process.get_extra_info('username')) + username = process.get_extra_info('username') + process.stdout.write(f'Welcome to my SSH server, {username}!\n') process.exit(0) class MySSHServer(asyncssh.SSHServer): def connection_made(self, conn: asyncssh.SSHServerConnection) -> None: - print('SSH connection received from %s.' % - conn.get_extra_info('peername')[0]) + peername = conn.get_extra_info('peername')[0] + print(f'SSH connection received from {peername}.') def connection_lost(self, exc: Optional[Exception]) -> None: if exc: @@ -49,21 +49,25 @@ def connection_lost(self, exc: Optional[Exception]) -> None: def begin_auth(self, username: str) -> bool: # If the user's password is the empty string, no auth is required - return passwords.get(username) != '' + return passwords.get(username) != b'' def password_auth_supported(self) -> bool: return True def validate_password(self, username: str, password: str) -> bool: - pw = passwords.get(username, '*') - return crypt.crypt(password, pw) == pw + if username not in passwords: + return False + pw = passwords[username] + if not password and not pw: + return True + return bcrypt.checkpw(password.encode('utf-8'), pw) async def start_server() -> None: await asyncssh.create_server(MySSHServer, '', 8022, server_host_keys=['ssh_host_key'], process_factory=handle_client) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/simple_sftp_server.py b/examples/simple_sftp_server.py index 12d1856..ce30102 100755 --- a/examples/simple_sftp_server.py +++ b/examples/simple_sftp_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2015-2021 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,7 +34,7 @@ async def start_server() -> None: authorized_client_keys='ssh_user_ca', sftp_factory=True) -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/stream_direct_client.py b/examples/stream_direct_client.py index 8c796a7..6155115 100755 --- a/examples/stream_direct_client.py +++ b/examples/stream_direct_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -35,6 +35,6 @@ async def run_client() -> None: sys.stdout.buffer.write(response) try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/examples/stream_direct_server.py b/examples/stream_direct_server.py index a691e27..dff07c6 100755 --- a/examples/stream_direct_server.py +++ b/examples/stream_direct_server.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -57,7 +57,7 @@ async def start_server() -> None: server_host_keys=['ssh_host_key'], authorized_client_keys='ssh_user_ca') -loop = asyncio.get_event_loop() +loop = asyncio.new_event_loop() try: loop.run_until_complete(start_server()) diff --git a/examples/stream_listening_client.py b/examples/stream_listening_client.py index 828a785..3b3266d 100755 --- a/examples/stream_listening_client.py +++ b/examples/stream_listening_client.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.7 # -# Copyright (c) 2013-2021 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -30,7 +30,7 @@ async def handle_connection(reader, writer): writer.close() def connection_requested(orig_host, orig_port): - print('Connection received from %s, port %s' % (orig_host, orig_port)) + print(f'Connection received from {orig_host}, port {orig_port}') return handle_connection async def run_client(): @@ -40,6 +40,6 @@ async def run_client(): await server.wait_closed() try: - asyncio.get_event_loop().run_until_complete(run_client()) + asyncio.run(run_client()) except (OSError, asyncssh.Error) as exc: sys.exit('SSH connection failed: ' + str(exc)) diff --git a/mypy.ini b/mypy.ini index 068f091..004d747 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,2 @@ [mypy] allow_redefinition = True -ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ea30886 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = ['setuptools'] +build-backend = 'setuptools.build_meta' + +[project] +name = 'asyncssh' +license = {text = 'EPL-2.0 OR GPL-2.0-or-later'} +description = 'AsyncSSH: Asynchronous SSHv2 client and server library' +readme = 'README.rst' +authors = [{name = 'Ron Frederick', email = 'ronf@timeheart.net'}] +classifiers = [ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'License :: OSI Approved', + 'Operating System :: MacOS :: MacOS X', + 'Operating System :: POSIX', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', + 'Topic :: Internet', + 'Topic :: Security :: Cryptography', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: System :: Networking', +] +requires-python = '>= 3.6' +dependencies = [ + 'cryptography >= 39.0', + 'typing_extensions >= 4.0.0', +] +dynamic = ['version'] + +[project.optional-dependencies] +bcrypt = ['bcrypt >= 3.1.3'] +fido2 = ['fido2 >= 0.9.2, < 2'] +gssapi = ['gssapi >= 1.2.0'] +libnacl = ['libnacl >= 1.4.2'] +pkcs11 = ['python-pkcs11 >= 0.7.0'] +pyOpenSSL = ['pyOpenSSL >= 23.0.0'] +pywin32 = ['pywin32 >= 227'] + + +[project.urls] +Homepage = 'http://asyncssh.timeheart.net' +Documentation = 'https://asyncssh.readthedocs.io' +Source = 'https://github.com/ronf/asyncssh' +Tracker = 'https://github.com/ronf/asyncssh/issues' + +[tool.setuptools.dynamic] +version = {attr = 'asyncssh.version.__version__'} + +[tool.setuptools.packages.find] +include = ['asyncssh*'] + +[tool.setuptools.package-data] +asyncssh = ['py.typed'] diff --git a/setup.py b/setup.py deleted file mode 100755 index b9df796..0000000 --- a/setup.py +++ /dev/null @@ -1,89 +0,0 @@ -#!/usr/bin/env python3.6 - -# Copyright (c) 2013-2022 by Ron Frederick and others. -# -# This program and the accompanying materials are made available under -# the terms of the Eclipse Public License v2.0 which accompanies this -# distribution and is available at: -# -# http://www.eclipse.org/legal/epl-2.0/ -# -# This program may also be made available under the following secondary -# licenses when the conditions for such availability set forth in the -# Eclipse Public License v2.0 are satisfied: -# -# GNU General Public License, Version 2.0, or any later versions of -# that license -# -# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later -# -# Contributors: -# Ron Frederick - initial implementation, API, and documentation - -"""AsyncSSH: Asynchronous SSHv2 client and server library - -AsyncSSH is a Python package which provides an asynchronous client and -server implementation of the SSHv2 protocol on top of the Python asyncio -framework. It requires Python 3.6 or later and the PyCA library for some -cryptographic functions. - -""" - -from os import path -from setuptools import setup - -base_dir = path.abspath(path.dirname(__file__)) - -doclines = __doc__.split('\n', 1) - -with open(path.join(base_dir, 'README.rst')) as desc: - long_description = desc.read() - -with open(path.join(base_dir, 'asyncssh', 'version.py')) as version: - exec(version.read()) - -setup(name = 'asyncssh', - version = __version__, - author = __author__, - author_email = __author_email__, - url = __url__, - project_urls = { - 'Documentation': 'https://asyncssh.readthedocs.io', - 'Source': 'https://github.com/ronf/asyncssh', - 'Tracker': 'https://github.com/ronf/asyncssh/issues' - }, - license = 'Eclipse Public License v2.0', - description = doclines[0], - long_description = long_description, - platforms = 'Any', - python_requires = '>= 3.6', - install_requires = ['cryptography >= 3.1', 'typing_extensions >= 3.6'], - extras_require = { - 'bcrypt': ['bcrypt >= 3.1.3'], - 'fido2': ['fido2 >= 0.9.2'], - 'gssapi': ['gssapi >= 1.2.0'], - 'libnacl': ['libnacl >= 1.4.2'], - 'pkcs11': ['python-pkcs11 >= 0.7.0'], - 'pyOpenSSL': ['pyOpenSSL >= 17.0.0'], - 'pywin32': ['pywin32 >= 227'] - }, - packages = ['asyncssh', 'asyncssh.crypto'], - package_data = {'asyncssh': ['py.typed']}, - scripts = [], - test_suite = 'tests', - classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'License :: OSI Approved', - 'Operating System :: MacOS :: MacOS X', - 'Operating System :: POSIX', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Internet', - 'Topic :: Security :: Cryptography', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Topic :: System :: Networking']) diff --git a/tests/gssapi_stub.py b/tests/gssapi_stub.py index 4f4a185..586f781 100644 --- a/tests/gssapi_stub.py +++ b/tests/gssapi_stub.py @@ -20,6 +20,8 @@ """Stub GSSAPI module for unit tests""" +from enum import IntEnum + from asyncssh.gss import GSSError from .gss_stub import step @@ -38,7 +40,9 @@ def __init__(self, base, _name_type=None): class Credentials: """Stub class for GSS credentials""" - def __init__(self, name=None, usage=None): + def __init__(self, name=None, usage=None, store=None): + # pylint: disable=unused-argument + self.host = name.host if name else '' self.server = usage == 'accept' @@ -52,12 +56,14 @@ def mechs(self): return [2] -class RequirementFlag: +class RequirementFlag(IntEnum): """Stub class for GSS requirement flags""" - mutual_authentication = 'mutual_auth' - integrity = 'integrity' - delegate_to_peer = 'delegate' + # pylint: disable=invalid-name + + delegate_to_peer = 1 + mutual_authentication = 2 + integrity = 4 class SecurityContext: @@ -67,12 +73,12 @@ def __init__(self, name=None, creds=None, flags=None): host = creds.host if creds.server else name.host if flags is None: - flags = set((RequirementFlag.mutual_authentication, - RequirementFlag.integrity)) + flags = RequirementFlag.mutual_authentication | \ + RequirementFlag.integrity if ((creds.server and 'no_server_integrity' in host) or (not creds.server and 'no_client_integrity' in host)): - flags.remove(RequirementFlag.integrity) + flags &= ~RequirementFlag.integrity self._host = host self._server = creds.server @@ -121,7 +127,10 @@ def step(self, token=None): def get_signature(self, _data): """Sign a block of data""" - return b'fail' if 'fail' in self._host else 'succeed' + if 'sign_error' in self._host: + raise GSSError(99, 99) + + return b'fail' if 'verify_error' in self._host else b'' def verify_signature(self, _data, sig): """Verify a signature for a block of data""" diff --git a/tests/keysign_stub.py b/tests/keysign_stub.py index 3848407..db10add 100644 --- a/tests/keysign_stub.py +++ b/tests/keysign_stub.py @@ -48,7 +48,7 @@ async def communicate(self, request): elif version == 1: return b'', b'invalid request' else: - skey = asyncssh.load_keypairs('skey')[0] + skey = asyncssh.load_keypairs('skey_ecdsa')[0] sig = skey.sign(data) return String(Byte(KEYSIGN_VERSION) + String(sig)), b'' diff --git a/tests/pkcs11_stub.py b/tests/pkcs11_stub.py index 2902845..472ed2b 100644 --- a/tests/pkcs11_stub.py +++ b/tests/pkcs11_stub.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021 by Ron Frederick and others. +# Copyright (c) 2020-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -23,7 +23,7 @@ import asyncssh from asyncssh.asn1 import der_decode from asyncssh.pkcs11 import pkcs11_available -from asyncssh.public_key import generate_private_key +from .util import get_test_key if pkcs11_available: # pragma: no branch import pkcs11 @@ -55,8 +55,8 @@ def _encode_public(key): class _PKCS11Key: """Stub for unit testing PKCS#11 keys""" - def __init__(self, alg, key_type, key_label, key_id): - self._priv = generate_private_key(alg, comment=key_label) + def __init__(self, alg_name, key_type, key_label, key_id): + self._priv = get_test_key(alg_name, key_id, comment=key_label) self.key_type = key_type self.label = key_label self.id = key_id @@ -65,7 +65,7 @@ def get_cert(self): """Return self-signed X.509 cert for this key""" return self._priv.generate_x509_user_certificate( - self._priv, 'OU=%s,CN=ckey' % self.label) + self._priv, f'OU={self.label},CN=ckey') def get_public(self): """Return public key corresponding to this key""" diff --git a/tests/server.py b/tests/server.py index 5df4247..58153f6 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -26,11 +26,13 @@ import signal import socket import subprocess +import sys import asyncssh from asyncssh.misc import async_context_manager -from .util import AsyncTestCase, all_tasks, current_task, run, x509_available +from .util import AsyncTestCase, all_tasks, current_task, get_test_key +from .util import run, x509_available class Server(asyncssh.SSHServer): @@ -109,12 +111,12 @@ async def asyncSetUpClass(cls): # pylint: disable=too-many-statements - ckey = asyncssh.generate_private_key('ssh-rsa') + ckey = get_test_key('ssh-rsa') ckey.write_private_key('ckey') ckey.write_private_key('ckey_encrypted', passphrase='passphrase') ckey.write_public_key('ckey.pub') - ckey_ecdsa = asyncssh.generate_private_key('ecdsa-sha2-nistp256') + ckey_ecdsa = get_test_key('ecdsa-sha2-nistp256') ckey_ecdsa.write_private_key('ckey_ecdsa') ckey_ecdsa.write_public_key('ckey_ecdsa.pub') @@ -122,19 +124,22 @@ async def asyncSetUpClass(cls): principals=['ckey']) ckey_cert.write_certificate('ckey-cert.pub') - skey = asyncssh.generate_private_key('ssh-rsa') + skey = get_test_key('ssh-rsa', 1) skey.write_private_key('skey') skey.write_public_key('skey.pub') - skey_ecdsa = asyncssh.generate_private_key('ecdsa-sha2-nistp256') + skey_ecdsa = get_test_key('ecdsa-sha2-nistp256', 1) skey_ecdsa.write_private_key('skey_ecdsa') skey_ecdsa.write_public_key('skey_ecdsa.pub') - skey_cert = skey.generate_host_certificate(skey, 'name', - principals=['127.0.0.1', - 'localhost']) + skey_cert = skey.generate_host_certificate( + skey, 'name', principals=['127.0.0.1', 'localhost']) skey_cert.write_certificate('skey-cert.pub') + skey_ecdsa_cert = skey_ecdsa.generate_host_certificate( + skey_ecdsa, 'name', principals=['127.0.0.1', 'localhost']) + skey_ecdsa_cert.write_certificate('skey_ecdsa-cert.pub') + exp_cert = skey.generate_host_certificate(skey, 'name', valid_after='-2d', valid_before='-1d') @@ -155,7 +160,7 @@ async def asyncSetUpClass(cls): skey_x509_self.append_certificate('skey_x509_self', 'pem') skey_x509_self.write_certificate('skey_x509_self.pem', 'pem') - root_ca_key = asyncssh.generate_private_key('ssh-rsa') + root_ca_key = get_test_key('ssh-rsa', 2) root_ca_key.write_private_key('root_ca_key') root_ca_cert = root_ca_key.generate_x509_ca_certificate( @@ -163,7 +168,7 @@ async def asyncSetUpClass(cls): root_ca_cert.write_certificate('root_ca_cert.pem', 'pem') root_ca_cert.write_certificate('root_ca_cert.pub') - int_ca_key = asyncssh.generate_private_key('ssh-rsa') + int_ca_key = get_test_key('ssh-rsa', 3) int_ca_key.write_private_key('int_ca_key') int_ca_cert = root_ca_key.generate_x509_ca_certificate( @@ -241,7 +246,7 @@ async def asyncSetUpClass(cls): cls._server_addr = '127.0.0.1' cls._server_port = sock.getsockname()[1] - host = '[%s]:%d,localhost ' % (cls._server_addr, cls._server_port) + host = f'[{cls._server_addr}]:{cls._server_port},localhost ' with open('known_hosts', 'w') as known_hosts: known_hosts.write(host) @@ -265,17 +270,20 @@ async def asyncSetUpClass(cls): if 'XAUTHORITY' in os.environ: # pragma: no cover del os.environ['XAUTHORITY'] - try: - output = run('ssh-agent -a agent 2>/dev/null') - except subprocess.CalledProcessError: # pragma: no cover - cls._agent_pid = None - else: - cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) + if sys.platform != 'win32': + try: + output = run('ssh-agent -a agent 2>/dev/null') + except subprocess.CalledProcessError: # pragma: no cover + cls._agent_pid = None + else: + cls._agent_pid = int(output.splitlines()[2].split()[3][:-1]) - os.environ['SSH_AUTH_SOCK'] = 'agent' + os.environ['SSH_AUTH_SOCK'] = 'agent' - async with asyncssh.connect_agent() as agent: - await agent.add_keys([ckey_ecdsa, (ckey, ckey_cert)]) + async with asyncssh.connect_agent() as agent: + await agent.add_keys([ckey_ecdsa, (ckey, ckey_cert)]) + else: # pragma: no cover + cls._agent_pid = None with open('ssh-keysign', 'wb'): pass @@ -284,14 +292,14 @@ async def asyncSetUpClass(cls): async def asyncTearDownClass(cls): """Shut down test server and agent""" + cls._server.close() + await cls._server.wait_closed() + tasks = all_tasks() tasks.remove(current_task()) await asyncio.gather(*tasks, return_exceptions=True) - cls._server.close() - await cls._server.wait_closed() - if cls._agent_pid: # pragma: no branch os.kill(cls._agent_pid, signal.SIGTERM) @@ -325,6 +333,22 @@ async def connect_reverse(self, options=None, gss_host=None, **kwargs): self._server_port, options=options, **kwargs) + @async_context_manager + async def run_client(self, sock, config=(), options=None, **kwargs): + """Run an SSH client on an already-connected socket""" + + return await asyncssh.run_client(sock, config, options, **kwargs) + + @async_context_manager + async def run_server(self, sock, config=(), options=None, **kwargs): + """Run an SSH server on an already-connected socket""" + + options = asyncssh.SSHServerConnectionOptions(options, + server_factory=Server, + server_host_keys=['skey']) + + return await asyncssh.run_server(sock, config, options, **kwargs) + async def create_connection(self, client_factory, **kwargs): """Create a connection to the test server""" diff --git a/tests/sk_stub.py b/tests/sk_stub.py index ffba14c..0926e4e 100644 --- a/tests/sk_stub.py +++ b/tests/sk_stub.py @@ -27,7 +27,7 @@ from asyncssh.asn1 import der_encode, der_decode from asyncssh.crypto import ECDSAPrivateKey, EdDSAPrivateKey from asyncssh.packet import Byte, UInt32 -from asyncssh.sk import sk_available +from asyncssh.sk import sk_available, sk_webauthn_prefix if sk_available: # pragma: no branch from asyncssh.sk import SSH_SK_ECDSA, SSH_SK_ED25519 @@ -86,16 +86,45 @@ def __init__(self, auth_data): self.auth_data = auth_data -class _CtapStub: - """Stub for unit testing U2F security key support""" +class _AttestationResponse: + """Security key attestation response""" - _version = None + def __init__(self, attestation_object): + self.attestation_object = attestation_object - def __init__(self, dev): - if dev.version != self._version: - raise ValueError('Wrong protocol version') - self.dev = dev +class _AuthenticatorData: + """Security key authenticator data in aseertion""" + + def __init__(self, flags, counter): + self.flags = flags + self.counter = counter + + +class _AssertionResponse: + """Security key aseertion response""" + + def __init__(self, client_data, auth_data, signature): + self.client_data = client_data + self.authenticator_data = auth_data + self.signature = signature + + +class _AssertionSelection: + """Security key assertion response list""" + + def __init__(self, assertions): + self._assertions = assertions + + def get_response(self, index): + """Return the assertion at specified index""" + + return self._assertions[index] + + +class _CtapStub: + """Stub for unit testing U2F security key support""" + @staticmethod def _enroll(alg): @@ -135,11 +164,8 @@ def _sign(message_hash, app_hash, key_handle, flags): class Ctap1(_CtapStub): """Stub for unit testing U2F security keys using CTAP version 1""" - _version = 1 - def __init__(self, dev): - super().__init__(dev) - + self.dev = dev self._polled = False def _poll(self): @@ -182,7 +208,11 @@ def authenticate(self, message_hash, app_hash, key_handle): class Ctap2(_CtapStub): """Stub for unit testing U2F security keys using CTAP version 2""" - _version = 2 + def __init__(self, dev): + if dev.version != 2: + raise ValueError('Wrong protocol version') + + self.dev = dev def make_credential(self, client_data_hash, rp, user, key_params, options, pin_uv_param, pin_uv_protocol): @@ -228,6 +258,50 @@ def get_assertions(self, application, message_hash, allow_creds, options): return [_Assertion(_AuthData(flags, counter), sig)] +class WindowsClient(_CtapStub): + """Stub for unit testing U2F security keys via Windows WebAuthn""" + + def __init__(self, origin, verify): + self._origin = origin + self._verify = verify + + def make_credential(self, options): + """Make a credential using Windows WebAuthN API""" + + self._verify(options['rp']['id'], self._origin) + + alg = options['pubKeyCredParams'][0]['alg'] + + public_key, key_handle = self._enroll(alg) + + cdata = _CredentialData(alg, public_key, key_handle) + + return _AttestationResponse(_Credential(_CredentialAuthData(cdata))) + + def get_assertion(self, options): + """Get assertion using Windows WebAuthN API""" + + self._verify(options['rpId'], self._origin) + + challenge = options['challenge'] + application = options['rpId'] + key_handle = options['allowCredentials'][0]['id'] + flags = SSH_SK_USER_PRESENCE_REQD + + app_hash = sha256(application.encode()).digest() + + data = sk_webauthn_prefix(challenge, application) + b'}' + message_hash = sha256(data).digest() + + flags, counter, sig = self._sign(message_hash, app_hash, + key_handle, flags) + + auth_data = _AuthenticatorData(flags, counter) + assertion = _AssertionResponse(data, auth_data, sig) + + return _AssertionSelection([assertion]) + + class CredentialManagement: """Stub for unit testing U2F security device resident keys""" @@ -301,13 +375,15 @@ class PinProtocolV1: VERSION = 1 -def stub_sk(devices): +def stub_sk(devices, use_webauthn=False): """Stub out security key module functions for unit testing""" devices = list(map(Device, devices)) old_ctap1 = asyncssh.sk.Ctap1 old_ctap2 = asyncssh.sk.Ctap2 + old_windows_client = asyncssh.sk.WindowsClient + old_use_webauthn = asyncssh.sk.sk_use_webauthn old_client_pin = asyncssh.sk.ClientPin old_cred_mgmt = asyncssh.sk.CredentialManagement old_pin_proto = asyncssh.sk.PinProtocolV1 @@ -315,21 +391,27 @@ def stub_sk(devices): asyncssh.sk.Ctap1 = Ctap1 asyncssh.sk.Ctap2 = Ctap2 + asyncssh.sk.WindowsClient = WindowsClient + asyncssh.sk.sk_use_webauthn = use_webauthn + asyncssh.sk_ecdsa.sk_use_webauthn = use_webauthn asyncssh.sk.ClientPin = ClientPin asyncssh.sk.CredentialManagement = CredentialManagement asyncssh.sk.PinProtocolV1 = PinProtocolV1 asyncssh.sk.CtapHidDevice.list_devices = lambda: iter(devices) - return old_ctap1, old_ctap2, old_client_pin, \ - old_cred_mgmt, old_pin_proto, old_list_devices + return old_ctap1, old_ctap2, old_windows_client, old_use_webauthn, \ + old_client_pin, old_cred_mgmt, old_pin_proto, old_list_devices -def unstub_sk(old_ctap1, old_ctap2, old_client_pin, old_cred_mgmt, - old_pin_proto, old_list_devices): +def unstub_sk(old_ctap1, old_ctap2, old_windows_client, old_use_webauthn, + old_client_pin, old_cred_mgmt, old_pin_proto, old_list_devices): """Restore security key module functions""" asyncssh.sk.Ctap1 = old_ctap1 asyncssh.sk.Ctap2 = old_ctap2 + asyncssh.sk.WindowsClient = old_windows_client + asyncssh.sk.sk_use_webauthn = old_use_webauthn + asyncssh.sk_ecdsa.sk_use_webauthn = old_use_webauthn asyncssh.sk.ClientPin = old_client_pin asyncssh.sk.CredentialManagement = old_cred_mgmt asyncssh.sk.PinProtocolV1 = old_pin_proto diff --git a/tests/sspi_stub.py b/tests/sspi_stub.py index df79836..99767eb 100644 --- a/tests/sspi_stub.py +++ b/tests/sspi_stub.py @@ -21,11 +21,12 @@ """Stub SSPI module for unit tests""" + import sys from .gss_stub import step -if sys.platform == 'win32': +if sys.platform == 'win32': # pragma: no cover from asyncssh.gss_win32 import ASC_RET_INTEGRITY, ISC_RET_INTEGRITY from asyncssh.gss_win32 import SECPKG_ATTR_NATIVE_NAMES, SSPIError @@ -53,7 +54,7 @@ def QueryContextAttributes(self, attr): # pylint: disable=invalid-name if attr == SECPKG_ATTR_NATIVE_NAMES: return ['user@TEST', 'host@TEST'] - else: + else: # pragma: no cover return None @@ -126,7 +127,10 @@ def sign(self, data): # pylint: disable=no-self-use,unused-argument - return b'fail' if 'fail' in self._host else 'succeed' + if 'sign_error' in self._host: + raise SSPIError('Signing error') + + return b'fail' if 'verify_error' in self._host else b'' def verify(self, data, sig): """Verify a signature for a block of data""" diff --git a/tests/test_agent.py b/tests/test_agent.py index 927f87a..322adec 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2020 by Ron Frederick and others. +# Copyright (c) 2016-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -36,7 +36,7 @@ from asyncssh.packet import Byte, String, UInt32 from .sk_stub import sk_available, patch_sk -from .util import AsyncTestCase, asynctest, run +from .util import AsyncTestCase, asynctest, get_test_key, run, try_remove def agent_test(func): @@ -85,7 +85,7 @@ async def stop(self): self._server.close() await self._server.wait_closed() - os.remove(self._path) + try_remove(self._path) class _TestAgent(AsyncTestCase): @@ -99,7 +99,7 @@ def set_askpass(status): """Set return status for ssh-askpass""" with open('ssh-askpass', 'w') as f: - f.write('#!/bin/sh\nexit %d\n' % status) + f.write(f'#!/bin/sh\nexit {status}\n') os.chmod('ssh-askpass', 0o755) # Pylint doesn't like mixed case method names, but this was chosen to @@ -173,13 +173,13 @@ async def test_get_keys(self, agent): async def test_sign(self, agent): """Test signing a block of data using the agent""" - algs = ['ssh-dss', 'ssh-rsa', 'ecdsa-sha2-nistp256'] + algs = ['ssh-rsa', 'ecdsa-sha2-nistp256'] if ed25519_available: # pragma: no branch algs.append('ssh-ed25519') for alg_name in algs: - key = asyncssh.generate_private_key(alg_name) + key = get_test_key(alg_name) pubkey = key.convert_to_public() cert = key.generate_user_certificate(key, 'name') @@ -187,6 +187,7 @@ async def test_sign(self, agent): agent_keys = await agent.get_keys() for agent_key in agent_keys: + agent_key.set_sig_algorithm(agent_key.sig_algorithms[0]) sig = await agent_key.sign_async(b'test') self.assertTrue(pubkey.verify(b'test', sig)) @@ -196,10 +197,10 @@ async def test_sign(self, agent): async def test_set_certificate(self, agent): """Test setting certificate on an existing keypair""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') cert = key.generate_user_certificate(key, 'name') - key2 = asyncssh.generate_private_key('ssh-rsa') + key2 = get_test_key('ssh-rsa', 1) cert2 = key.generate_user_certificate(key2, 'name') await agent.add_keys([key]) @@ -222,7 +223,7 @@ async def test_set_certificate(self, agent): async def test_reconnect(self, agent): """Test reconnecting to the agent after closing it""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() async with agent: @@ -241,7 +242,7 @@ async def test_add_remove_keys(self, agent): agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 0) - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') await agent.add_keys([key]) agent_keys = await agent.get_keys() self.assertEqual(len(agent_keys), 1) @@ -270,7 +271,7 @@ async def test_add_remove_keys(self, agent): async def test_add_nonlocal(self, agent): """Test failure when adding a non-local key to an agent""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') async with agent: await agent.add_keys([key]) @@ -284,7 +285,7 @@ async def test_add_keys_failure(self, agent): """Test failure adding keys to the agent""" os.mkdir('.ssh', 0o700) - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') key.write_private_key(Path('.ssh', 'id_rsa')) try: @@ -309,8 +310,7 @@ async def test_add_keys_failure(self, agent): async def test_add_sk_keys(self): """Test adding U2F security keys""" - key = asyncssh.generate_private_key( - 'sk-ecdsa-sha2-nistp256@openssh.com') + key = get_test_key('sk-ecdsa-sha2-nistp256@openssh.com') cert = key.generate_user_certificate(key, 'test') mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS)) @@ -321,9 +321,8 @@ async def test_add_sk_keys(self): async with agent: self.assertIsNone(await agent.add_keys([keypair])) - async with agent: - with self.assertRaises(asyncssh.KeyExportError): - await agent.add_keys([key.convert_to_public()]) + with self.assertRaises(asyncssh.KeyExportError): + await agent.add_keys([key.convert_to_public()]) await mock_agent.stop() @@ -333,8 +332,7 @@ async def test_add_sk_keys(self): async def test_get_sk_keys(self): """Test getting U2F security keys""" - key = asyncssh.generate_private_key( - 'sk-ecdsa-sha2-nistp256@openssh.com') + key = get_test_key('sk-ecdsa-sha2-nistp256@openssh.com') cert = key.generate_user_certificate(key, 'test') mock_agent = _Agent(Byte(SSH_AGENT_IDENTITIES_ANSWER) + UInt32(2) + @@ -374,7 +372,7 @@ async def test_add_remove_smartcard_keys(self): async def test_confirm(self, agent): """Test confirmation of key""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() await agent.add_keys([key], confirm=True) @@ -396,7 +394,7 @@ async def test_confirm(self, agent): async def test_lock(self, agent): """Test lock and unlock""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ecdsa-sha2-nistp256') pubkey = key.convert_to_public() await agent.add_keys([key]) @@ -458,7 +456,7 @@ async def test_query_extensions(self): async def test_unknown_key(self, agent): """Test failure when signing with an unknown key""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') with self.assertRaises(ValueError): await agent.sign(key.public_data, b'test') @@ -474,7 +472,7 @@ async def test_double_close(self, agent): async def test_errors(self): """Test getting error responses from SSH agent""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') keypair = asyncssh.load_keypairs(key)[0] for response in (None, b'', Byte(SSH_AGENT_FAILURE), b'\xff'): diff --git a/tests/test_auth.py b/tests/test_auth.py index ebecfe3..d49f08b 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2022 by Ron Frederick and others. +# Copyright (c) 2015-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -21,6 +21,7 @@ """Unit tests for authentication""" import asyncio +import inspect import unittest import asyncssh @@ -34,7 +35,7 @@ from asyncssh.packet import SSHPacket, Boolean, Byte, NameList, String from .util import asynctest, gss_available, patch_gss -from .util import AsyncTestCase, ConnectionStub +from .util import AsyncTestCase, ConnectionStub, get_test_key class _AuthConnectionStub(ConnectionStub): @@ -45,7 +46,7 @@ def connection_lost(self, exc): raise NotImplementedError - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" raise NotImplementedError @@ -96,7 +97,7 @@ def __init__(self, method, gss_host=None, override_gss_mech=False, password_change_prompt, kbdint_auth, kbdint_challenge, success), False) - self._gss = GSSClient(gss_host, False) if gss_host else None + self._gss = GSSClient(gss_host, None, False) if gss_host else None self._client_host_key = client_host_key self._client_host_cert = client_host_cert @@ -126,7 +127,7 @@ def connection_lost(self, exc=None): self.close() - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) @@ -154,16 +155,21 @@ def process_packet(self, data): self._auth = None self._auth_waiter = None else: - self._auth.process_packet(pkttype, None, packet) + result = self._auth.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result async def get_auth_result(self): """Return the result of the authentication""" return await self._auth_waiter - def try_next_auth(self): + def try_next_auth(self, *, next_method=False): """Handle a request to move to another form of auth""" + # pylint: disable=unused-argument + # Report that the current auth attempt failed self._auth_waiter.set_result((False, self._password_changed)) self._auth = None @@ -259,7 +265,7 @@ def __init__(self, peer=None, gss_host=None, override_gss_mech=False, kbdint_challenge=False, success=False): super().__init__(peer, True) - self._gss = GSSServer(gss_host) if gss_host else None + self._gss = GSSServer(gss_host, None) if gss_host else None self._override_gss_mech = override_gss_mech self._host_based_auth = host_based_auth @@ -285,7 +291,7 @@ def connection_lost(self, exc=None): self.close() - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) @@ -308,7 +314,10 @@ def process_packet(self, data): else: self._auth = lookup_server_auth(self, 'user', method, packet) else: - self._auth.process_packet(pkttype, None, packet) + result = self._auth.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result def send_userauth_failure(self, partial_success): """Send a user authentication failure response""" @@ -317,7 +326,7 @@ def send_userauth_failure(self, partial_success): self.send_userauth_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(partial_success)) - def send_userauth_success(self): + async def send_userauth_success(self): """Send a user authentication success response""" self._auth = None @@ -517,7 +526,7 @@ async def test_gss_auth(self): async def test_hostbased_auth(self): """Unit test host-based authentication""" - hkey = asyncssh.generate_private_key('ssh-rsa') + hkey = get_test_key('ecdsa-sha2-nistp256') cert = hkey.generate_host_certificate(hkey, 'host') with self.subTest('Host-based auth not available'): @@ -541,7 +550,7 @@ async def test_hostbased_auth(self): async def test_publickey_auth(self): """Unit test public key authentication""" - ckey = asyncssh.generate_private_key('ssh-rsa') + ckey = get_test_key('ecdsa-sha2-nistp256') cert = ckey.generate_user_certificate(ckey, 'name') with self.subTest('Public key auth not available'): diff --git a/tests/test_auth_keys.py b/tests/test_auth_keys.py index 27d4190..9e1a05c 100644 --- a/tests/test_auth_keys.py +++ b/tests/test_auth_keys.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,7 +24,7 @@ import asyncssh -from .util import TempDirTestCase, x509_available +from .util import TempDirTestCase, get_test_key, x509_available class _TestAuthorizedKeys(TempDirTestCase): @@ -43,12 +43,12 @@ def setUpClass(cls): super().setUpClass() for i in range(3): - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa', i) cls.keylist.append(key.export_public_key().decode('ascii')) cls.imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch - subject = 'CN=cert%s' % i + subject = f'CN=cert{i}' cert = key.generate_x509_user_certificate(key, subject) cls.certlist.append(cert.export_certificate().decode('ascii')) cls.imported_certlist.append(cert) @@ -63,7 +63,7 @@ def build_keys(self, keys, x509=False, from_file=False): keynum = 1 if 'cert-authority' in options else 0 key_or_cert = (self.certlist if x509 else self.keylist)[keynum] - auth_keys += '%s%s' % (options, key_or_cert) + auth_keys += options + key_or_cert if from_file: with open('authorized_keys', 'w') as f: @@ -197,7 +197,7 @@ def test_cert_authority_with_subject(self): def test_non_root_ca(self): """Test error on non-root X.509 CA""" - key = asyncssh.generate_private_key('ssh-rsa') + key = get_test_key('ssh-rsa') cert = key.generate_x509_user_certificate(key, 'CN=a', 'CN=b') data = 'cert-authority ' + cert.export_certificate().decode('ascii') diff --git a/tests/test_channel.py b/tests/test_channel.py index d8b0c82..0974f0e 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2022 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -23,6 +23,7 @@ import asyncio import os import tempfile +import unittest from signal import SIGINT from unittest.mock import patch @@ -40,6 +41,8 @@ from asyncssh.packet import Byte, String, UInt32 from asyncssh.public_key import CERT_TYPE_USER from asyncssh.stream import SSHTCPStreamSession, SSHUNIXStreamSession +from asyncssh.stream import SSHTunTapStreamSession +from asyncssh.tuntap import SSH_TUN_MODE_POINTTOPOINT, SSH_TUN_MODE_ETHERNET from .server import Server, ServerTestCase from .util import asynctest, echo, make_certificate @@ -53,9 +56,7 @@ class _ClientChannel(asyncssh.SSHClientChannel): def _send_request(self, request, *args, want_reply=False): """Send a channel request""" - if request == b'env' and args[1] == String('invalid'): - args = args[:1] + (String(b'\xff'),) - elif request == b'pty-req': + if request == b'pty-req': if args[5][-6:-5] == Byte(PTY_OP_PARTIAL): args = args[:5] + (String(args[5][4:-5]),) elif args[5][-6:-5] == Byte(PTY_OP_NO_END): @@ -246,6 +247,15 @@ def session_started(self): chan.close() +class _ClientSessionCleanupError(asyncssh.SSHClientSession): + """Test of exception during client session cleanup""" + + def connection_lost(self, exc): + """Raise an error when a client session is cleaned up""" + + raise RuntimeError('Exception in session cleanup test') + + class _ChannelServer(Server): """Server for testing the AsyncSSH channel API""" @@ -270,7 +280,7 @@ async def _begin_session(self, stdin, stdout, stderr): elif action == 'agent': try: async with asyncssh.connect_agent(self._conn) as agent: - stdout.write(str(len((await agent.get_keys()))) + '\n') + stdout.write(str(len(await agent.get_keys())) + '\n') except (OSError, asyncssh.ChannelOpenError): stdout.channel.exit(1) elif action == 'agent_sock': @@ -279,7 +289,7 @@ async def _begin_session(self, stdin, stdout, stderr): if agent_path: async with asyncssh.connect_agent(agent_path) as agent: await asyncio.sleep(0.1) - stdout.write(str(len((await agent.get_keys()))) + '\n') + stdout.write(str(len(await agent.get_keys())) + '\n') else: stdout.channel.exit(1) elif action == 'rejected_agent': @@ -343,6 +353,22 @@ async def _begin_session(self, stdin, stdout, stderr): await chan.accept(SSHUNIXStreamSession, b'\xff') except asyncssh.ChannelOpenError: stdout.channel.exit(1) + elif action == 'rejected_tun_request': + chan = self._conn.create_tuntap_channel() + + try: + await chan.open(SSHTunTapStreamSession, + SSH_TUN_MODE_POINTTOPOINT, 0) + except asyncssh.ChannelOpenError: + stdout.channel.exit(1) + elif action == 'rejected_tap_request': + chan = self._conn.create_tuntap_channel() + + try: + await chan.open(SSHTunTapStreamSession, + SSH_TUN_MODE_ETHERNET, 0) + except asyncssh.ChannelOpenError: + stdout.channel.exit(1) elif action == 'late_auth_banner': try: self._conn.send_auth_banner('auth banner') @@ -355,8 +381,21 @@ async def _begin_session(self, stdin, stdout, stderr): stdin.channel.send_packet(MSG_CHANNEL_OPEN_FAILURE, UInt32(0), String(''), String('')) elif action == 'env': + value = stdin.channel.get_environment_bytes().get(b'TEST', b'') + stdout.write(value.decode('utf-8', 'backslashreplace') + '\n') + elif action == 'env_binary_key': + value = stdin.channel.get_environment_bytes().get(b'TEST\xff', b'') + stdout.write(value.decode('utf-8', 'backslashreplace') + '\n') + elif action == 'env_str': value = stdin.channel.get_environment().get('TEST', '') stdout.write(value + '\n') + elif action == 'env_str_cached': + value1 = stdin.channel.get_environment().get('TEST', '') + value2 = stdin.channel.get_environment().get('TEST', '') + stdout.write(value1 + value2 + '\n') + elif action == 'env_non_string_key': + value = stdin.channel.get_environment().get('1', '') + stdout.write(value + '\n') elif action == 'term': chan = stdin.channel info = str((chan.get_terminal_type(), chan.get_terminal_size(), @@ -398,17 +437,21 @@ async def _begin_session(self, stdin, stdout, stderr): elif action == 'empty_data': stdin.channel.send_packet(MSG_CHANNEL_DATA, String('')) elif action == 'partial_unicode': - data = '\xff\xff'.encode('utf-8') + data = '\xff\xff'.encode() stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[3:])) elif action == 'partial_unicode_at_eof': - data = '\xff\xff'.encode('utf-8') + data = '\xff\xff'.encode() stdin.channel.send_packet(MSG_CHANNEL_DATA, String(data[:3])) elif action == 'unicode_error': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(b'\xff')) elif action == 'data_past_window': stdin.channel.send_packet(MSG_CHANNEL_DATA, String(2*1025*1024*'\0')) + elif action == 'ext_data_past_window': + stdin.channel.send_packet(MSG_CHANNEL_EXTENDED_DATA, + UInt32(asyncssh.EXTENDED_DATA_STDERR), + String(2*1025*1024*'\0')) elif action == 'data_after_eof': stdin.channel.send_packet(MSG_CHANNEL_EOF) stdout.write('xxx') @@ -733,7 +776,7 @@ async def test_unknown_channel_request(self): async with self.connect() as conn: chan, _ = await _create_session(conn) - self.assertFalse((await chan.make_request('unknown'))) + self.assertFalse(await chan.make_request('unknown')) @asynctest async def test_invalid_channel_request(self): @@ -828,14 +871,14 @@ async def test_unneeded_resume_reading(self): chan.close() @asynctest - async def test_agent_forwarding(self): - """Test SSH agent forwarding""" + async def test_agent_forwarding_explicit(self): + """Test SSH agent forwarding with explicit path""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') async with self.connect(username='ckey', - agent_forwarding=True) as conn: + agent_forwarding='agent') as conn: chan, session = await _create_session(conn, 'agent') await chan.wait_closed() @@ -908,6 +951,18 @@ async def test_invalid_unix_listener(self): await self._check_action('invalid_unix_listener', None) + @asynctest + async def test_rejected_tun_request(self): + """Test receiving inbound TUN request""" + + await self._check_action('rejected_tun_request', 1) + + @asynctest + async def test_rejected_tap_request(self): + """Test receiving inbound TAP request""" + + await self._check_action('rejected_tap_request', 1) + @asynctest async def test_agent_forwarding_failure(self): """Test failure of SSH agent forwarding""" @@ -1097,10 +1152,23 @@ async def test_term_modes_incomplete(self): @asynctest async def test_env(self): - """Test setting environment""" + """Test setting environment with byte strings""" async with self.connect() as conn: chan, session = await _create_session(conn, 'env', + env={b'TEST': b'test'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\n') + + @asynctest + async def test_env_str(self): + """Test setting environment using Unicode strings""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env_str', env={'TEST': 'test'}) await chan.wait_closed() @@ -1108,6 +1176,92 @@ async def test_env(self): result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') + @asynctest + async def test_env_str_cached(self): + """Test caching of Unicode string environment dict on server""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env_str_cached', + env={'TEST': 'test'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'testtest\n') + + @asynctest + async def test_env_invalid_str(self): + """Test trying to access binary envionment value as a Unicode string""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env_str', + env={'TEST': b'test\xff'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, '\n') + + @asynctest + async def test_env_binary_key(self): + """Test setting environment with binary data in key""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env_binary_key', + env={b'TEST\xff': 'test'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\n') + + @asynctest + async def test_env_binary_value(self): + """Test setting environment with binary data in value""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env', + env={'TEST': b'test\xff'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\\xff\n') + + @asynctest + async def test_env_non_string_key(self): + """Test setting environment with non-string as a key""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env_non_string_key', + env={1: 'test'}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\n') + + @asynctest + async def test_env_non_string_value(self): + """Test setting environment with non-string as a value""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env', + env={'TEST': 1}) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, '1\n') + + @asynctest + async def test_invalid_env(self): + """Test sending invalid environment""" + + async with self.connect() as conn: + with self.assertRaises(ValueError): + await _create_session(conn, 'env', env=1) + @asynctest async def test_env_from_connect(self): """Test setting environment on connection""" @@ -1133,6 +1287,32 @@ async def test_env_list(self): result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') + @asynctest + async def test_env_list_binary(self): + """Test setting environment using a list of name=value byte strings""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env', + env=[b'TEST=test\xff']) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\\xff\n') + + @asynctest + async def test_env_tuple(self): + """Test setting environment using a tuple of name=value strings""" + + async with self.connect() as conn: + chan, session = await _create_session(conn, 'env', + env=('TEST=test',)) + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\n') + @asynctest async def test_invalid_env_list(self): """Test setting environment using an invalid string""" @@ -1158,6 +1338,25 @@ async def test_send_env(self): result = ''.join(session.recv_buf[None]) self.assertEqual(result, 'test\n') + @unittest.skipUnless(os.supports_bytes_environ, + 'skip binary send env if not supported by OS') + @asynctest + async def test_send_env_binary(self): + """Test sending local environment using a byte string""" + + async with self.connect() as conn: + try: + os.environb[b'TEST'] = b'test\xff' + chan, session = await _create_session(conn, 'env', + send_env=[b'TEST']) + finally: + del os.environb[b'TEST'] + + await chan.wait_closed() + + result = ''.join(session.recv_buf[None]) + self.assertEqual(result, 'test\\xff\n') + @asynctest async def test_send_env_from_connect(self): """Test sending local environment on connection""" @@ -1193,20 +1392,6 @@ async def test_mixed_env(self): result = ''.join(session.recv_buf[None]) self.assertEqual(result, '2\n') - @asynctest - async def test_invalid_env(self): - """Test sending invalid environment""" - - with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): - async with self.connect() as conn: - chan, session = await _create_session( - conn, 'env', env={'TEST': 'invalid'}) - - await chan.wait_closed() - - result = ''.join(session.recv_buf[None]) - self.assertEqual(result, '\n') - @asynctest async def test_xon_xoff_enable(self): """Test enabling XON/XOFF flow control""" @@ -1478,6 +1663,15 @@ async def test_data_past_window(self): await chan.wait_closed() + @asynctest + async def test_ext_data_past_window(self): + """Test receiving an extended data packet past the advertised window""" + + async with self.connect() as conn: + chan, _ = await _create_session(conn, 'ext_data_past_window') + + await chan.wait_closed() + @asynctest async def test_data_after_eof(self): """Test receiving data after EOF""" @@ -1576,6 +1770,13 @@ async def test_unknown_action(self): await chan.wait_closed() self.assertEqual(session.exit_status, 255) + @asynctest + async def test_client_session_cleanup_error(self): + """Test error in client session cleanup""" + + async with self.connect() as conn: + await conn.create_session(_ClientSessionCleanupError) + class _TestChannelNoPTY(ServerTestCase): """Unit tests for AsyncSSH channel module with PTYs disallowed""" @@ -1663,8 +1864,9 @@ async def test_dropbear_client(self): """Test reduced dropbear send packet size""" with patch('asyncssh.connection.SSHServerChannel', _ServerChannel): - async with self.connect(client_version='dropbear', - max_pktsize=32759) as conn: + async with self.connect( + client_version='dropbear', max_pktsize=32759, + compression_algs=['zlib@openssh.com']) as conn: _, stdout, _ = await conn.open_session('send_pktsize') self.assertEqual((await stdout.read()), '32758') @@ -1690,7 +1892,8 @@ async def test_dropbear_server(self): """Test reduced dropbear send packet size""" with patch('asyncssh.connection.SSHClientChannel', _ClientChannel): - async with self.connect() as conn: + async with self.connect( + compression_algs='zlib@openssh.com') as conn: stdin, _, _ = await conn.open_session() self.assertEqual(stdin.channel.get_send_pktsize(), 32758) diff --git a/tests/test_config.py b/tests/test_config.py index 31f055e..b229d3a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022 by Ron Frederick and others. +# Copyright (c) 2020-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -71,7 +71,7 @@ def test_set_bool(self): for value, result in (('yes', True), ('true', True), ('no', False), ('false', False)): - config = self._parse_config('Compression %s' % value) + config = self._parse_config(f'Compression {value}') self.assertEqual(config.get('Compression'), result) config = self._parse_config('Compression yes\nCompression no') @@ -101,13 +101,26 @@ def test_set_address_family(self): for family, result in (('any', socket.AF_UNSPEC), ('inet', socket.AF_INET), ('inet6', socket.AF_INET6)): - config = self._parse_config('AddressFamily %s' % family) + config = self._parse_config(f'AddressFamily {family}') self.assertEqual(config.get('AddressFamily'), result) config = self._parse_config('AddressFamily inet\n' 'AddressFamily inet6') self.assertEqual(config.get('AddressFamily'), socket.AF_INET) + def test_set_canonicaize_host(self): + """Test canonicalize host config option""" + + for value, result in (('yes', True), ('true', True), + ('no', False), ('false', False), + ('always', 'always')): + config = self._parse_config(f'CanonicalizeHostname {value}') + self.assertEqual(config.get('CanonicalizeHostname'), result) + + config = self._parse_config('CanonicalizeHostname yes\n' + 'CanonicalizeHostname no') + self.assertEqual(config.get('CanonicalizeHostname'), True) + def test_set_rekey_limit(self): """Test rekey limit config option""" @@ -117,7 +130,7 @@ def test_set_rekey_limit(self): ('default', ((), ())), ('default 2', ((), '2')), ('default none', ((), None))): - config = self._parse_config('RekeyLimit %s' % value) + config = self._parse_config(f'RekeyLimit {value}') self.assertEqual(config.get('RekeyLimit'), result) config = self._parse_config('RekeyLimit 1 2\nRekeyLimit 3 4') @@ -143,8 +156,8 @@ def test_include(self): with open('.ssh/include', 'w') as f: f.write('Port 2222') - for path in ('include', Path('.ssh/include').absolute().as_posix()): - config = self._parse_config('Include %s' % path) + for path in ('include', Path('.ssh/include').resolve().as_posix()): + config = self._parse_config(f'Include {path}') self.assertEqual(config.get('Port'), 2222) def test_missing_include(self): @@ -182,6 +195,24 @@ def test_match_all(self): config = self._parse_config('Match user xxx\nMatch all\nPort 2222') self.assertEqual(config.get('Port'), 2222) + def test_match_negated(self): + """Test a match block which never matches due to negation""" + + config = self._parse_config('Match !all user xxx\nPort 2222') + self.assertEqual(config.get('Port'), None) + + def test_match_canonical(self): + """Test a match block which matches when the host is canonicalized""" + + config = self._parse_config('Match canonical\nPort 2222') + self.assertEqual(config.get('Port'), None) + + def test_match_final(self): + """Test a match block which matches on the final parsing pass""" + + config = self._parse_config('Match final\nPort 2222') + self.assertEqual(config.get('Port'), None) + def test_match_exec(self): """Test a match block which runs a subprocess""" @@ -214,7 +245,7 @@ def test_equals(self): """Test config option with equals instead of space""" for delimiter in ('=', ' =', '= ', ' = '): - config = self._parse_config('Compression%syes' % delimiter) + config = self._parse_config(f'Compression{delimiter}yes') self.assertEqual(config.get('Compression'), True) def test_unknown(self): @@ -231,6 +262,8 @@ def test_errors(self): ('Unbalanced quotes', 'BindAddress "foo'), ('Extra data at end', 'BindAddress foo bar'), ('Invalid address family', 'AddressFamily xxx'), + ('Invalid canonicalization option', + 'CanonicalizeHostname xxx'), ('Invalid boolean', 'Compression xxx'), ('Invalid integer', 'Port xxx'), ('Invalid match condition', 'Match xxx')): @@ -243,13 +276,14 @@ class _TestClientConfig(_TestConfig): """Unit tests for client config objects""" def _load_config(self, config, last_config=None, reload=False, - local_user='user', user=(), host='host', port=()): + canonical=False, final=False, local_user='user', + user=(), host='host', port=()): """Load a client configuration""" # pylint: disable=arguments-differ - return SSHClientConfig.load(last_config, config, reload, - local_user, user, host, port) + return SSHClientConfig.load(last_config, config, reload, canonical, + final, local_user, user, host, port) def test_set_string_none(self): """Test string config option""" @@ -279,25 +313,47 @@ def test_set_string_list(self): 'UserKnownHostsFile file2') self.assertEqual(config.get('UserKnownHostsFile'), ['file1']) + config = self._parse_config('UserKnownHostsFile none\n' + 'UserKnownHostsFile file2') + self.assertEqual(config.get('UserKnownHostsFile'), []) + def test_append_string_list(self): """Test appending multiple string config options to a list""" - config = self._parse_config('SendEnv foo\nSendEnv bar baz') + config = self._parse_config('SendEnv foo\nSendEnv bar baz') self.assertEqual(config.get('SendEnv'), ['foo', 'bar', 'baz']) + def test_set_environment(self): + """Test setting environment with equals-separated key/value pairs""" + + config = self._parse_config('SetEnv A=1 B= C=D=2\nSetEnv E=3') + self.assertEqual(config.get('SetEnv'), ['A=1', 'B=', 'C=D=2']) + def test_set_remote_command(self): """Test setting a remote command""" config = self._parse_config(' RemoteCommand foo bar baz') self.assertEqual(config.get('RemoteCommand'), 'foo bar baz') + def test_set_forward_agent(self): + """Test agent forwarding path config option""" + + for value, result in (('yes', True), ('true', True), + ('no', False), ('false', False), + ('agent', 'agent'), ('%d/agent', './agent')): + config = self._parse_config(f'ForwardAgent {value}') + self.assertEqual(config.get('ForwardAgent'), result) + + config = self._parse_config('ForwardAgent yes\nForwardAgent no') + self.assertEqual(config.get('ForwardAgent'), True) + def test_set_request_tty(self): """Test pseudo-terminal request config option""" for value, result in (('yes', True), ('true', True), ('no', False), ('false', False), ('force', 'force'), ('auto', 'auto')): - config = self._parse_config('RequestTTY %s' % value) + config = self._parse_config(f'RequestTTY {value}') self.assertEqual(config.get('RequestTTY'), result) config = self._parse_config('RequestTTY yes\nRequestTTY no') @@ -345,6 +401,17 @@ def test_set_and_match_user(self): self.assertEqual(config.get('BindAddress'), 'addr') self.assertEqual(config.get('Port'), 2222) + def test_tag(self): + """Test setting and matching a tag""" + + config = self._parse_config('Tag tag2\n' + 'Match tagged tag1\n' + ' Port 1111\n' + 'Match tagged tag*\n' + ' Port 2222') + + self.assertEqual(config.get('Port'), 2222) + def test_port_already_set(self): """Test that port is ignored if set outside of the config""" @@ -417,8 +484,14 @@ def mock_expanduser(path): return path + def mock_pathlib_expanduser(self): + """Expand user even with os.path.expanduser mocked out""" + + return Path(os.environ['HOME'], *self.parts[1:]) + with self.assertRaises(asyncssh.ConfigParseError): - with patch('os.path.expanduser', mock_expanduser): + with patch('os.path.expanduser', mock_expanduser), \ + patch('pathlib.Path.expanduser', mock_pathlib_expanduser): self._parse_config('RemoteCommand %d') def test_uid_percent_expansion_unavailable(self): @@ -429,7 +502,7 @@ def test_uid_percent_expansion_unavailable(self): def mock_hasattr(obj, attr): if obj == os and attr == 'getuid': return False - else: + else: # pragma: no cover return orig_hasattr(obj, attr) with self.assertRaises(asyncssh.ConfigParseError): @@ -441,24 +514,39 @@ def test_invalid_percent_expansion(self): for desc, config_data in ( ('Bad token in hostname', 'Hostname %p'), - ('Invalid token', 'IdentityFile %x'), - ('Percent at end', 'IdentityFile %')): + ('Invalid token', 'IdentityFile %x')): with self.subTest(desc): with self.assertRaises(asyncssh.ConfigParseError): self._parse_config(config_data) + def test_env_expansion(self): + """Test environment variable expansion""" + + config = self._parse_config('RemoteCommand ${HOME}/.ssh') + + self.assertEqual(config.get('RemoteCommand'), './.ssh') + + def test_invalid_env_expansion(self): + """Test invalid environment variable expansion""" + + with self.assertRaises(asyncssh.ConfigParseError): + self._parse_config('RemoteCommand ${XXX}') + + class _TestServerConfig(_TestConfig): """Unit tests for server config objects""" def _load_config(self, config, last_config=None, reload=False, + canonical=False, final=False, local_addr='127.0.0.1', local_port=22, user='user', host=None, addr='127.0.0.1'): """Load a server configuration""" # pylint: disable=arguments-differ - return SSHServerConfig.load(last_config, config, reload, - local_addr, local_port, user, host, addr) + return SSHServerConfig.load(last_config, config, reload, canonical, + final, local_addr, local_port, user, + host, addr) def test_match_local_address(self): """Test matching on local address""" diff --git a/tests/test_connection.py b/tests/test_connection.py index 9a3871c..e567426 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2022 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -25,17 +25,18 @@ import os from pathlib import Path import socket -import subprocess import sys import unittest from unittest.mock import patch import asyncssh -from asyncssh.constants import MSG_UNIMPLEMENTED, MSG_DEBUG +from asyncssh.constants import MSG_IGNORE, MSG_DEBUG from asyncssh.constants import MSG_SERVICE_REQUEST, MSG_SERVICE_ACCEPT from asyncssh.constants import MSG_KEXINIT, MSG_NEWKEYS +from asyncssh.constants import MSG_KEX_FIRST, MSG_KEX_LAST from asyncssh.constants import MSG_USERAUTH_REQUEST, MSG_USERAUTH_SUCCESS from asyncssh.constants import MSG_USERAUTH_FAILURE, MSG_USERAUTH_BANNER +from asyncssh.constants import MSG_USERAUTH_FIRST from asyncssh.constants import MSG_GLOBAL_REQUEST from asyncssh.constants import MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_CONFIRMATION from asyncssh.constants import MSG_CHANNEL_OPEN_FAILURE, MSG_CHANNEL_DATA @@ -43,23 +44,18 @@ from asyncssh.crypto.cipher import GCMCipher from asyncssh.encryption import get_encryption_algs from asyncssh.kex import get_kex_algs +from asyncssh.kex_dh import MSG_KEX_ECDH_REPLY from asyncssh.mac import _HMAC, _mac_handler, get_mac_algs -from asyncssh.packet import Boolean, NameList, String, UInt32 +from asyncssh.packet import SSHPacket, Boolean, NameList, String, UInt32 from asyncssh.public_key import get_default_public_key_algs from asyncssh.public_key import get_default_certificate_algs from asyncssh.public_key import get_default_x509_certificate_algs from .server import Server, ServerTestCase -from .util import asynctest, gss_available, patch_gss, run -from .util import patch_getnameinfo, x509_available - - -try: - run('which nc') - _nc_available = True -except subprocess.CalledProcessError: # pragma: no cover - _nc_available = False +from .util import asynctest, patch_extra_kex, patch_getaddrinfo +from .util import patch_getnameinfo, patch_gss +from .util import gss_available, nc_available, x509_available class _CheckAlgsClientConnection(asyncssh.SSHClientConnection): @@ -178,6 +174,27 @@ def _send_ext_info(self): super()._send_ext_info() +class _BadSignatureServerConnection(asyncssh.SSHServerConnection): + """Test returning a bad signature in host keys prove request""" + + def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( + self, packet): + """Prove the server has private keys for all requested host keys""" + + self._report_global_response(String(b'')) + + +class _ProveFailedServerConnection(asyncssh.SSHServerConnection): + """Test returning failure in host keys prove request""" + + def _process_hostkeys_prove_00_at_openssh_dot_com_global_request( + self, packet): + """Prove the server has private keys for all requested host keys""" + + super()._process_hostkeys_prove_00_at_openssh_dot_com_global_request( + SSHPacket(String(b''))) + + def _failing_get_mac(alg, key): """Replace HMAC class with FailingMAC""" @@ -193,6 +210,12 @@ def verify(self, seq, packet, sig): return _FailingMAC(key, hash_size, *args) +async def _slow_connect(*_args, **_kwargs): + """Simulate a really slow connect that ends up timing out""" + + await asyncio.sleep(5) + + class _FailingGCMCipher(GCMCipher): """Test error in GCM tag verification""" @@ -259,6 +282,17 @@ def connection_made(self, conn): raise RuntimeError('Exception handler test') +class _ClientCleanupError(asyncssh.SSHClient): + """Test of exception during client cleanup""" + + def connection_lost(self, exc): + """Raise an error when a client is cleaned up""" + + # pylint: disable=unused-argument + + raise RuntimeError('Exception in cleanup test') + + class _TunnelServer(Server): """Allow forwarding to test server host key request tunneling""" @@ -339,14 +373,6 @@ def begin_auth(self, username): return False -def disconnect_on_unimplemented(self, pkttype, pktid, packet): - """Process an unimplemented message response""" - - # pylint: disable=unused-argument - - self.disconnect(asyncssh.DISC_BY_APPLICATION, 'Unexpected response') - - @patch_gss @patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection) class _TestConnection(ServerTestCase): @@ -388,9 +414,71 @@ async def _check_version(self, *args, **kwargs): async def test_connect(self): """Test connecting with async context manager""" - async with self.connect(): + async with self.connect() as conn: + pass + + self.assertTrue(conn.is_closed()) + + @asynctest + async def test_connect_sock(self): + """Test connecting using an already-connected socket""" + + sock = socket.socket() + await self.loop.sock_connect(sock, (self._server_addr, + self._server_port)) + + async with asyncssh.connect(sock=sock): + pass + + @unittest.skipUnless(nc_available, 'Netcat not available') + @asynctest + async def test_connect_non_tcp_sock(self): + """Test connecting using an non-TCP socket""" + + sock1, sock2 = socket.socketpair() + + proc = await asyncio.create_subprocess_exec( + 'nc', str(self._server_addr), str(self._server_port), + stdin=sock1, stdout=sock1, stderr=sock1) + + async with asyncssh.connect( + self._server_addr, self._server_port, sock=sock2): + pass + + await proc.wait() + sock1.close() + + @asynctest + async def test_run_client(self): + """Test running an SSH client on an already-connected socket""" + + sock = socket.socket() + await self.loop.sock_connect(sock, (self._server_addr, + self._server_port)) + + async with self.run_client(sock): + pass + + @asynctest + async def test_connect_encrypted_key(self): + """Test connecting with encrypted client key and no passphrase""" + + async with self.connect(client_keys='ckey_encrypted', + ignore_encrypted=True): pass + with self.assertRaises(asyncssh.KeyImportError): + await self.connect(client_keys='ckey_encrypted') + + with open('config', 'w') as f: + f.write('IdentityFile ckey_encrypted') + + async with self.connect(config='config'): + pass + + with self.assertRaises(asyncssh.KeyImportError): + await self.connect(config='config', ignore_encrypted=False) + @asynctest async def test_connect_invalid_options_type(self): """Test connecting using options using incorrect type of options""" @@ -426,22 +514,28 @@ async def test_connect_timeout_exceeded(self): """Test connect timeout exceeded""" with self.assertRaises(asyncio.TimeoutError): - await asyncssh.connect('223.255.255.254', connect_timeout=1) + with patch('asyncio.BaseEventLoop.create_connection', + _slow_connect): + await asyncssh.connect('', connect_timeout=1) @asynctest async def test_connect_timeout_exceeded_string(self): """Test connect timeout exceeded with string value""" with self.assertRaises(asyncio.TimeoutError): - await asyncssh.connect('223.255.255.254', connect_timeout='0m1s') + with patch('asyncio.BaseEventLoop.create_connection', + _slow_connect): + await asyncssh.connect('', connect_timeout='0m1s') @asynctest async def test_connect_timeout_exceeded_tunnel(self): """Test connect timeout exceeded""" with self.assertRaises(asyncio.TimeoutError): - await asyncssh.listen(server_host_keys=['skey'], - tunnel='223.255.255.254', connect_timeout=1) + with patch('asyncio.BaseEventLoop.create_connection', + _slow_connect): + await asyncssh.listen(server_host_keys=['skey'], + tunnel='', connect_timeout=1) @asynctest async def test_invalid_connect_timeout(self): @@ -529,6 +623,14 @@ async def test_duplicate_type_server_host_keys(self): with self.assertRaises(ValueError): await asyncssh.listen(server_host_keys=['skey', 'skey']) + @asynctest + async def test_reserved_server_host_keys(self): + """Test reserved host keys with host key sending enabled""" + + async with self.listen(server_host_keys=['skey', 'skey'], + send_server_host_keys=True): + pass + @asynctest async def test_get_server_host_key(self): """Test retrieving a server host key""" @@ -555,7 +657,7 @@ async def test_get_server_host_key_connect_failure(self): with self.assertRaises(OSError): await asyncssh.get_server_host_key('\xff') - @unittest.skipUnless(_nc_available, 'Netcat not available') + @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_get_server_host_key_proxy(self): """Test retrieving a server host key using proxy command""" @@ -567,12 +669,12 @@ async def test_get_server_host_key_proxy(self): self.assertEqual(key, keylist[0]) - @unittest.skipUnless(_nc_available, 'Netcat not available') + @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_get_server_host_key_proxy_failure(self): """Test failure retrieving a server host key using proxy command""" - # Leave out arguments to 'nc' to trigger a faliure + # Leave out arguments to 'nc' to trigger a failure proxy_command = 'nc' with self.assertRaises((OSError, asyncssh.ConnectionLost)): @@ -616,6 +718,16 @@ async def test_known_hosts_none(self): async with self.connect(known_hosts=None) as conn: self.assertEqual(conn.get_server_host_key_algs(), default_algs) + @asynctest + async def test_known_hosts_none_in_config(self): + """Test connecting with known hosts checking disabled in config file""" + + with open('config', 'w') as f: + f.write('UserKnownHostsFile none') + + async with self.connect(config='config'): + pass + @asynctest async def test_known_hosts_none_without_x509(self): """Test connecting with known hosts checking and X.509 disabled""" @@ -698,7 +810,7 @@ async def test_import_known_hosts(self): known_hosts_path = os.path.join('.ssh', 'known_hosts') - with open(known_hosts_path, 'r') as f: + with open(known_hosts_path) as f: known_hosts = asyncssh.import_known_hosts(f.read()) async with self.connect(known_hosts=known_hosts): @@ -812,7 +924,7 @@ async def test_kex_algs(self): @asynctest async def test_duplicate_encryption_algs(self): - """Test connecting with an duplicated encryption algorithm""" + """Test connecting with a duplicated encryption algorithm""" with patch('asyncssh.connection.SSHClientConnection', _CheckAlgsClientConnection): @@ -891,22 +1003,6 @@ def unsupported_kex_alg(): with self.assertRaises(asyncssh.KeyExchangeFailed): await self.connect(kex_algs=['fail']) - @asynctest - async def test_skip_ext_info(self): - """Test not requesting extension info from the server""" - - def skip_ext_info(self): - """Don't request extension information""" - - # pylint: disable=unused-argument - - return [] - - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', - skip_ext_info): - async with self.connect(): - pass - @asynctest async def test_unknown_ext_info(self): """Test receiving unknown extension information""" @@ -928,8 +1024,56 @@ def send_newkeys(self, k, h): with patch('asyncssh.connection.SSHClientConnection.send_newkeys', send_newkeys): - async with self.connect(): - pass + with self.assertRaises((ConnectionError, asyncssh.ProtocolError)): + await self.connect() + + @asynctest + async def test_message_before_kexinit_strict_kex(self): + """Test receiving a message before KEXINIT with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_message_during_kex_strict_kex(self): + """Test receiving an unexpected message with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_unknown_message_during_kex_strict_kex(self): + """Test receiving an unknown message with strict_kex enabled""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_KEX_LAST) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() @asynctest async def test_encryption_algs(self): @@ -1055,21 +1199,107 @@ async def test_invalid_debug(self): await conn.wait_closed() @asynctest - async def test_invalid_service_request(self): - """Test invalid service request""" + async def test_service_request_before_kex_complete(self): + """Test service request before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self._kex_complete = True + + self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHClientConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_service_accept_before_kex_complete(self): + """Test service accept before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self._kex_complete = True + + self.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_unexpected_service_name_in_request(self): + """Test unexpected service name in service request""" conn = await self.connect() conn.send_packet(MSG_SERVICE_REQUEST, String('xxx')) await conn.wait_closed() @asynctest - async def test_invalid_service_accept(self): - """Test invalid service accept""" + async def test_unexpected_service_name_in_accept(self): + """Test unexpected service name in accept sent by server""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + self.send_packet(MSG_SERVICE_ACCEPT, String('xxx')) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ServiceNotAvailable): + await self.connect() + + @asynctest + async def test_service_accept_from_client(self): + """Test service accept sent by client""" conn = await self.connect() - conn.send_packet(MSG_SERVICE_ACCEPT, String('xxx')) + conn.send_packet(MSG_SERVICE_ACCEPT, String('ssh-userauth')) await conn.wait_closed() + @asynctest + async def test_service_request_from_server(self): + """Test service request sent by server""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + self.send_packet(MSG_SERVICE_REQUEST, String('ssh-userauth')) + + with patch('asyncssh.connection.SSHServerConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + + @asynctest + async def test_client_decompression_failure(self): + """Test client decompression failure""" + + def send_packet(self, pkttype, *args, **kwargs): + """Send an SSH packet""" + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + if pkttype == MSG_USERAUTH_SUCCESS: + self._compressor = None + self.send_debug('Test') + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + await self.connect(compression_algs=['zlib@openssh.com']) + @asynctest async def test_packet_decode_error(self): """Test SSH packet decode error""" @@ -1276,6 +1506,41 @@ async def test_invalid_newkeys(self): conn.send_packet(MSG_NEWKEYS) await conn.wait_closed() + @asynctest + async def test_kex_after_kex_complete(self): + """Test kex request when kex not in progress""" + + conn = await self.connect() + conn.send_packet(MSG_KEX_FIRST) + await conn.wait_closed() + + @asynctest + async def test_userauth_after_auth_complete(self): + """Test userauth request when auth not in progress""" + + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_FIRST) + await conn.wait_closed() + + @asynctest + async def test_userauth_before_kex_complete(self): + """Test receiving userauth before kex is complete""" + + def send_newkeys(self, k, h): + """Finish a key exchange and send a new keys message""" + + self._kex_complete = True + + self.send_packet(MSG_USERAUTH_REQUEST, String('guest'), + String('ssh-connection'), String('none')) + + asyncssh.connection.SSHConnection.send_newkeys(self, k, h) + + with patch('asyncssh.connection.SSHClientConnection.send_newkeys', + send_newkeys): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect() + @asynctest async def test_invalid_userauth_service(self): """Test invalid service in userauth request""" @@ -1325,25 +1590,32 @@ async def test_extra_userauth_request(self): String('ssh-connection'), String('none')) await asyncio.sleep(0.1) + @asynctest + async def test_late_userauth_request(self): + """Test userauth request after auth is final""" + + async with self.connect() as conn: + conn.send_packet(MSG_GLOBAL_REQUEST, String('xxx'), + Boolean(False)) + conn.send_packet(MSG_USERAUTH_REQUEST, String('guest'), + String('ssh-connection'), String('none')) + await conn.wait_closed() + @asynctest async def test_unexpected_userauth_success(self): """Test unexpected userauth success response""" - with patch.dict('asyncssh.connection.SSHConnection._packet_handlers', - {MSG_UNIMPLEMENTED: disconnect_on_unimplemented}): - conn = await self.connect() - conn.send_packet(MSG_USERAUTH_SUCCESS) - await conn.wait_closed() + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_SUCCESS) + await conn.wait_closed() @asynctest async def test_unexpected_userauth_failure(self): """Test unexpected userauth failure response""" - with patch.dict('asyncssh.connection.SSHConnection._packet_handlers', - {MSG_UNIMPLEMENTED: disconnect_on_unimplemented}): - conn = await self.connect() - conn.send_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(False)) - await conn.wait_closed() + conn = await self.connect() + conn.send_packet(MSG_USERAUTH_FAILURE, NameList([]), Boolean(False)) + await conn.wait_closed() @asynctest async def test_unexpected_userauth_banner(self): @@ -1458,6 +1730,210 @@ async def test_internal_error(self): with self.assertRaises(RuntimeError): await self.create_connection(_InternalErrorClient) + @asynctest + async def test_client_cleanup_error(self): + """Test error in client cleanup""" + + async with self.connect(client_factory=_ClientCleanupError): + pass + + +@patch_extra_kex +class _TestConnectionNoStrictKex(ServerTestCase): + """Unit tests for connection API with ext info and strict kex disabled""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + return (await cls.create_server(_TunnelServer, gss_host=(), + compression_algs='*', + encryption_algs='*', + kex_algs='*', mac_algs='*')) + + @asynctest + async def test_skip_ext_info(self): + """Test not requesting extension info from the server""" + + async with self.connect(): + pass + + @asynctest + async def test_message_before_kexinit(self): + """Test receiving a message before KEXINIT""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + async with self.connect(): + pass + + @asynctest + async def test_message_during_kex(self): + """Test receiving an unexpected message in key exchange""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEX_ECDH_REPLY: + self.send_packet(MSG_IGNORE, String(b'')) + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHServerConnection.send_packet', + send_packet): + async with self.connect(): + pass + + @asynctest + async def test_sequence_wrap_during_kex(self): + """Test sequence wrap during initial key exchange""" + + def send_packet(self, pkttype, *args, **kwargs): + if pkttype == MSG_KEXINIT: + if self._options.command == 'send': + self._send_seq = 0xfffffffe + else: + self._recv_seq = 0xfffffffe + + asyncssh.connection.SSHConnection.send_packet( + self, pkttype, *args, **kwargs) + + with patch('asyncssh.connection.SSHClientConnection.send_packet', + send_packet): + with self.assertRaises(asyncssh.ProtocolError): + await self.connect(command='send') + + with self.assertRaises(asyncssh.ProtocolError): + await self.connect(command='recv') + + +class _TestConnectionHostKeysHandler(ServerTestCase): + """Unit test for specifying a host keys handler""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + return (await cls.create_server( + server_host_keys=['skey', 'skey_ecdsa'], + send_server_host_keys=True)) + + async def _check_host_keys(self, host_keys, known_hosts, expected): + """Check server host keys handler""" + + def host_keys_handler(*results): + """Check reported host keys against expected value""" + + self.assertEqual([len(r) for r in results], expected) + conn.close() + + async def async_host_keys_handler(*results): + """Check async version of server host keys handler""" + + host_keys_handler(*results) + + self._server.update(server_host_keys=host_keys) + + conn = await self.connect(server_host_keys_handler=host_keys_handler, + known_hosts=known_hosts) + + if expected is None: + await asyncio.sleep(0.1) + conn.close() + + await conn.wait_closed() + + if expected: + conn = await self.connect( + server_host_keys_handler=async_host_keys_handler, + known_hosts=known_hosts) + + await conn.wait_closed() + + + @asynctest + async def test_host_key_handler_disabled(self): + """Test server host keys handler being disabled""" + + async with self.connect(): + await asyncio.sleep(0.1) + + @asynctest + async def test_host_key_added(self): + """Test server host keys handler showing a key added""" + + await self._check_host_keys(['skey', 'skey_ecdsa'], + [['skey'], [], []], + [1, 0, 1, 0]) + + @asynctest + async def test_host_key_removed(self): + """Test server host keys handler showing a key removed""" + + await self._check_host_keys(['skey'], [['skey', 'skey_ecdsa'], [], []], + [0, 1, 1, 0]) + + @asynctest + async def test_host_key_revoked(self): + """Test server host keys handler showing a key revoked""" + + await self._check_host_keys(['skey', 'skey_ecdsa'], + [['skey'], [], ['skey_ecdsa']], + [0, 0, 1, 1]) + + @asynctest + async def test_no_trusted_hosts(self): + """Test server host keys handler is disabled due to no trusted hosts""" + + await self._check_host_keys(['skey'], None, None) + + @asynctest + async def test_host_key_bad_signature(self): + """Test server host keys handler getting back a bad signature""" + + with patch('asyncssh.connection.SSHServerConnection', + _BadSignatureServerConnection): + await self._check_host_keys(['skey', 'skey_ecdsa'], + [['skey'], [], []], + [0, 0, 1, 0]) + + @asynctest + async def test_host_key_prove_failed(self): + """Test server host keys handler getting back a prove failure""" + + with patch('asyncssh.connection.SSHServerConnection', + _ProveFailedServerConnection): + await self._check_host_keys(['skey', 'skey_ecdsa'], + [['skey'], [], []], + [0, 0, 1, 0]) + + +class _TestConnectionListenSock(ServerTestCase): + """Unit test for specifying a listen socket""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + sock = socket.socket() + sock.bind(('', 0)) + + return await cls.create_server(_TunnelServer, sock=sock) + + @asynctest + async def test_connect(self): + """Test specifying explicit listen sock""" + + with self.assertLogs(level='INFO'): + async with self.connect(): + pass + class _TestConnectionAsyncAcceptor(ServerTestCase): """Unit test for async acceptor""" @@ -1527,7 +2003,29 @@ async def test_connect_reverse(self): async with self.connect_reverse(): pass - @unittest.skipUnless(_nc_available, 'Netcat not available') + @asynctest + async def test_connect_reverse_sock(self): + """Test reverse connection using an already-connected socket""" + + sock = socket.socket() + await self.loop.sock_connect(sock, (self._server_addr, + self._server_port)) + + async with self.connect_reverse(sock=sock): + pass + + @asynctest + async def test_run_server(self): + """Test running an SSH server on an already-connected socket""" + + sock = socket.socket() + await self.loop.sock_connect(sock, (self._server_addr, + self._server_port)) + + async with self.run_server(sock): + pass + + @unittest.skipUnless(nc_available, 'Netcat not available') @asynctest async def test_connect_reverse_proxy(self): """Test reverse direction SSH connection with proxy command""" @@ -1691,14 +2189,14 @@ async def test_connect_x509_self(self): @asynctest async def test_connect_x509_untrusted_self(self): - """Test connecting with untrusted X.509 self-signed certficate""" + """Test connecting with untrusted X.509 self-signed certificate""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(x509_trusted_certs='root_ca_cert.pem') @asynctest async def test_connect_x509_revoked_self(self): - """Test connecting with revoked X.509 self-signed certficate""" + """Test connecting with revoked X.509 self-signed certificate""" with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect(known_hosts=([], [], [], ['root_ca_cert.pem'], @@ -1894,11 +2392,9 @@ def client_factory(): return _ValidateHostKeyClient(host_key='skey.pub') - algs = [asyncssh.read_public_key('skey.pub').get_algorithm()] - - conn, _ = await self.create_connection(client_factory, - known_hosts=([], [], []), - server_host_key_algs=algs) + conn, _ = await self.create_connection( + client_factory, known_hosts=([], [], []), + server_host_key_algs=['rsa-sha2-256']) async with conn: pass @@ -1980,6 +2476,13 @@ async def test_host_key_mismatch(self): with self.assertRaises(asyncssh.HostKeyNotVerifiable): await self.connect() + @asynctest + async def test_host_key_unknown(self): + """Test unknown host key alias""" + + with self.assertRaises(asyncssh.HostKeyNotVerifiable): + await self.connect(host_key_alias='unknown') + @asynctest async def test_host_key_match(self): """Test host key match""" @@ -2182,9 +2685,129 @@ async def test_ssh_listen_context_manager(self): """Test using an SSH listener as a context manager""" async with self.listen() as server: - listen_port = server.sockets[0].getsockname()[1] - + listen_port = server.get_port() async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): pass + + +@patch_getaddrinfo +class _TestCanonicalizeHost(ServerTestCase): + """Test hostname canonicalization""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + return await cls.create_server(_TunnelServer) + + @asynctest + async def test_canonicalize(self): + """Test hostname canonicalization""" + + async with self.connect('testhost', known_hosts=None, + canonicalize_hostname=True, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_max_dots(self): + """Test hostname canonicalization exceeding max_dots""" + + async with self.connect('testhost.test', known_hosts=None, + canonicalize_hostname=True, + canonicalize_max_dots=0, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_ip_address(self): + """Test hostname canonicalization with IP address""" + + async with self.connect('127.0.0.1', known_hosts=None, + canonicalize_hostname=True, + canonicalize_max_dots=3, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), '127.0.0.1') + + @asynctest + async def test_canonicalize_proxy(self): + """Test hostname canonicalization with proxy""" + + with open('config', 'w') as f: + f.write('UserKnownHostsFile none\n') + + async with self.connect('testhost', config='config', + tunnel=f'localhost:{self._server_port}', + canonicalize_hostname=True, + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_always(self): + """Test hostname canonicalization for all connections""" + + with open('config', 'w') as f: + f.write('UserKnownHostsFile none\n') + + async with self.connect('testhost', config='config', + tunnel=f'localhost:{self._server_port}', + canonicalize_hostname='always', + canonical_domains=['test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testhost.test') + + @asynctest + async def test_canonicalize_failure(self): + """Test hostname canonicalization failure""" + + with self.assertRaises(socket.gaierror): + await self.connect('unknown', known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test']) + + @asynctest + async def test_canonicalize_failed_no_fallback(self): + """Test hostname canonicalization""" + + with self.assertRaises(OSError): + await self.connect('unknown', known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_fallback_local=False) + + @asynctest + async def test_cname_returned(self): + """Test hostname canonicalization with cname returned""" + + async with self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + [('*.test', '*.test')]) as conn: + self.assertEqual(conn.get_extra_info('host'), 'cname.test') + + @asynctest + async def test_cname_not_returned(self): + """Test hostname canonicalization with cname not returned""" + + async with self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + ['*.xxx:*.test']) as conn: + self.assertEqual(conn.get_extra_info('host'), 'testcname.test') + + @asynctest + async def test_bad_cname_rules(self): + """Test hostname canonicalization with bad cname rules""" + + with self.assertRaises(ValueError): + await self.connect('testcname', + known_hosts=(['skey.pub'], [], []), + canonicalize_hostname=True, + canonical_domains=['test'], + canonicalize_permitted_cnames= \ + ['*.xxx:*.test:*.xxx']) diff --git a/tests/test_connection_auth.py b/tests/test_connection_auth.py index a8f7615..4554d92 100644 --- a/tests/test_connection_auth.py +++ b/tests/test_connection_auth.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2022 by Ron Frederick and others. +# Copyright (c) 2016-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -22,10 +22,13 @@ import asyncio import os +import sys import unittest from unittest.mock import patch +from cryptography.exceptions import UnsupportedAlgorithm + import asyncssh from asyncssh.misc import async_context_manager, write_file from asyncssh.packet import String @@ -33,8 +36,9 @@ from .keysign_stub import create_subprocess_exec_stub from .server import Server, ServerTestCase -from .util import asynctest, gss_available, patch_getnameinfo, patch_gss -from .util import make_certificate, x509_available +from .util import asynctest, gss_available, patch_getnameinfo +from .util import patch_getnameinfo_error, patch_gss +from .util import make_certificate, nc_available, x509_available class _FailValidateHostSSHServerConnection(asyncssh.SSHServerConnection): @@ -71,6 +75,9 @@ async def begin_auth(self, username): return False + async def auth_completed(self): + """Handle client authentication request""" + class _HostBasedServer(Server): """Server for testing host-based authentication""" @@ -410,11 +417,11 @@ async def validate_kbdint_response(self, username, responses): class _UnknownAuthClientConnection(asyncssh.connection.SSHClientConnection): """Test getting back an unknown auth method from the SSH server""" - def try_next_auth(self): + def try_next_auth(self, *, next_method=False): """Attempt client authentication using an unknown method""" self._auth_methods = [b'unknown'] + self._auth_methods - super().try_next_auth() + super().try_next_auth(next_method=next_method) class _TestNullAuth(ServerTestCase): @@ -455,11 +462,13 @@ async def test_disabled_trivial_auth(self): class _TestGSSAuth(ServerTestCase): """Unit tests for GSS authentication""" + @unittest.skipIf(sys.platform == 'win32', 'skip GSS store test on Windows') @classmethod async def start_server(cls): """Start an SSH server which supports GSS authentication""" - return await cls.create_server(_AsyncGSSServer, gss_host='1') + return await cls.create_server(_AsyncGSSServer, gss_host='1', + gss_store='a') @asynctest async def test_get_server_auth_methods(self): @@ -486,6 +495,31 @@ async def test_gss_mic_auth(self): username='user', gss_host='1'): pass + @unittest.skipIf(sys.platform == 'win32', 'skip GSS store test on Windows') + @asynctest + async def test_gss_mic_auth_store(self): + """Test GSS MIC authentication with GSS store set""" + + async with self.connect(kex_algs=['ecdh-sha2-nistp256'], + username='user', gss_host='1', gss_store='a'): + pass + + @asynctest + async def test_gss_mic_auth_sign_error(self): + """Test GSS MIC authentication signing failure""" + + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(kex_algs=['ecdh-sha2-nistp256'], + username='user', gss_host='1,sign_error') + + @asynctest + async def test_gss_mic_auth_verify_error(self): + """Test GSS MIC authentication signature verification failure""" + + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(kex_algs=['ecdh-sha2-nistp256'], + username='user', gss_host='1,verify_error') + @asynctest async def test_gss_delegate(self): """Test GSS credential delegation""" @@ -642,6 +676,17 @@ async def test_get_server_auth_methods(self): self.assertEqual(auth_methods, ['hostbased']) + @unittest.skipUnless(nc_available, 'Netcat not available') + @asynctest + async def test_get_server_auth_methods_no_sockname(self): + """Test getting auth methods from the test server""" + + proxy_command = ('nc', str(self._server_addr), str(self._server_port)) + + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(username='user', client_host_keys='skey', + proxy_command=proxy_command) + @asynctest async def test_client_host_auth(self): """Test connecting with host-based authentication""" @@ -693,7 +738,7 @@ async def test_client_host_key_keypairs(self): async def test_client_host_signature_algs(self): """Test host based authentication with specific signature algorithms""" - for alg in ('ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'): + for alg in ('rsa-sha2-256', 'rsa-sha2-512'): async with self.connect(username='user', client_host_keys='skey', client_username='user', signature_algs=[alg]): @@ -710,10 +755,14 @@ def skip_ext_info(self): return [] - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', + with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): - async with self.connect(username='user', client_host_keys='skey', - client_username='user'): + try: + async with self.connect(username='user', + client_host_keys='skey', + client_username='user'): + pass + except UnsupportedAlgorithm: # pragma: no cover pass @asynctest @@ -815,6 +864,26 @@ async def test_disabled_trivial_client_host_auth(self): disable_trivial_auth=True) +class _TestHostBasedAuthNoRDNS(ServerTestCase): + """Unit tests for host-based authentication with no reverse DNS""" + + @classmethod + async def start_server(cls): + """Start an SSH server which supports host-based authentication""" + + return await cls.create_server( + _HostBasedServer, known_client_hosts='known_hosts') + + @patch_getnameinfo_error + @asynctest + async def test_client_host_auth_no_rdns(self): + """Test connecting with host-based authentication with no RDNS""" + + async with self.connect(username='user', client_host_keys='skey', + client_username='user'): + pass + + @patch_getnameinfo class _TestCallbackHostBasedAuth(ServerTestCase): """Unit tests for host-based authentication using callback""" @@ -873,8 +942,8 @@ class _TestKeysignHostBasedAuth(ServerTestCase): async def start_server(cls): """Start an SSH server which supports host-based authentication""" - return await cls.create_server(_HostBasedServer, - known_client_hosts='known_hosts') + return await cls.create_server( + _HostBasedServer, known_client_hosts=(['skey_ecdsa.pub'], [], [])) @async_context_manager async def _connect_keysign(self, client_host_keysign=True, @@ -886,7 +955,7 @@ async def _connect_keysign(self, client_host_keysign=True, with patch('asyncssh.keysign._DEFAULT_KEYSIGN_DIRS', keysign_dirs): with patch('asyncssh.public_key._DEFAULT_HOST_KEY_DIRS', ['.']): with patch('asyncssh.public_key._DEFAULT_HOST_KEY_FILES', - ['skey', 'xxx']): + ['skey_ecdsa', 'xxx']): return await self.connect( username='user', client_host_keysign=client_host_keysign, @@ -911,7 +980,7 @@ async def test_explciit_keysign(self): async def test_keysign_explicit_host_keys(self): """Test ssh-keysign with explicit host public keys""" - async with self._connect_keysign(client_host_keys='skey.pub'): + async with self._connect_keysign(client_host_keys='skey_ecdsa.pub'): pass @asynctest @@ -1006,9 +1075,12 @@ async def test_mismatched_host_signature_algs(self): async def test_host_signature_alg_fallback(self): """Test fall back to default host key signature algorithm""" - async with self.connect(username='ckey', client_host_keys='skey', - client_username='user', - signature_algs=['rsa-sha2-256', 'ssh-rsa']): + try: + async with self.connect(username='ckey', client_host_keys='skey', + client_username='user', + signature_algs=['rsa-sha2-256', 'ssh-rsa']): + pass + except UnsupportedAlgorithm: # pragma: no cover pass @@ -1054,13 +1126,63 @@ async def test_encrypted_client_key(self): passphrase='passphrase'): pass + @asynctest + async def test_encrypted_client_key_callable(self): + """Test public key auth with callable passphrase""" + + def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', client_keys='ckey_encrypted', + agent_path=None, passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_awaitable(self): + """Test public key auth with awaitable passphrase""" + + async def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', client_keys='ckey_encrypted', + agent_path=None,passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_list_callable(self): + """Test public key auth with callable passphrase""" + + def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', + client_keys=['ckey_encrypted'], + agent_path=None, passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_list_awaitable(self): + """Test public key auth with awaitable passphrase""" + + async def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', + client_keys=['ckey_encrypted'], + agent_path=None, passphrase=_passphrase): + pass + @asynctest async def test_encrypted_client_key_bad_passphrase(self): """Test wrong passphrase for encrypted client key""" with self.assertRaises(asyncssh.KeyEncryptionError): await self.connect(username='ckey', client_keys='ckey_encrypted', - passphrase='xxx') + agent_path=None, passphrase='xxx') @asynctest async def test_encrypted_client_key_missing_passphrase(self): @@ -1141,7 +1263,7 @@ async def test_agent_signature_algs(self): if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') - for alg in ('ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'): + for alg in ('rsa-sha2-256', 'rsa-sha2-512'): async with self.connect(username='ckey', signature_algs=[alg]): pass @@ -1161,7 +1283,8 @@ async def test_agent_auth_failure(self): async def test_agent_auth_unset(self): """Test connecting with no local keys and no ssh-agent configured""" - with patch.dict(os.environ, HOME='xxx', SSH_AUTH_SOCK=''): + with patch.dict(os.environ, HOME='xxx', USERPROFILE='xxx', + SSH_AUTH_SOCK=''): with self.assertRaises(asyncssh.PermissionDenied): await self.connect(username='ckey', known_hosts='.ssh/known_hosts') @@ -1193,9 +1316,9 @@ async def test_public_key_auth_not_preferred(self): async def test_public_key_signature_algs(self): """Test public key authentication with specific signature algorithms""" - for alg in ('ssh-rsa', 'rsa-sha2-256', 'rsa-sha2-512'): - async with self.connect(username='ckey', client_keys='ckey', - signature_algs=[alg]): + for alg in ('rsa-sha2-256', 'rsa-sha2-512'): + async with self.connect(username='ckey', agent_path=None, + client_keys='ckey', signature_algs=[alg]): pass @asynctest @@ -1209,10 +1332,13 @@ def skip_ext_info(self): return [] - with patch('asyncssh.connection.SSHConnection._get_ext_info_kex_alg', + with patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', skip_ext_info): - async with self.connect(username='ckey', client_keys='ckey', - agent_path=None): + try: + async with self.connect(username='ckey', client_keys='ckey', + agent_path=None): + pass + except UnsupportedAlgorithm: # pragma: no cover pass @asynctest @@ -1287,7 +1413,7 @@ async def test_keypair_with_replaced_cert(self): @asynctest async def test_agent_keypair_with_replaced_cert(self): - """Test connecting with sn agent key with replaced cert""" + """Test connecting with an agent key with replaced cert""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') @@ -1466,14 +1592,6 @@ async def test_mismatched_client_signature_algs(self): await self.connect(username='ckey', client_keys='ckey', signature_algs=['rsa-sha2-256']) - @asynctest - async def test_client_signature_alg_fallback(self): - """Test fall back to default client key signature algorithm""" - - async with self.connect(username='ckey', client_keys='ckey', - signature_algs=['rsa-sha2-256', 'ssh-rsa']): - pass - class _TestSetAuthorizedKeys(ServerTestCase): """Unit tests for public key authentication with set_authorized_keys""" @@ -1594,7 +1712,7 @@ async def test_keypair_with_x509_cert(self): @asynctest async def test_agent_keypair_with_x509_cert(self): - """Test connecting with sn agent key with replaced X.509 cert""" + """Test connecting with an agent key with replaced X.509 cert""" if not self.agent_available(): # pragma: no cover self.skipTest('ssh-agent not available') @@ -1764,6 +1882,36 @@ async def test_password_auth(self): async with self.connect(username='pw', password='pw', client_keys=None): pass + @asynctest + async def test_password_auth_callable(self): + """Test connecting with a callable for password authentication""" + + async with self.connect(username='pw', password=lambda: 'pw', + client_keys=None): + pass + + @asynctest + async def test_password_auth_async_callable(self): + """Test connecting with an async callable for password authentication""" + + async def get_password(): + return 'pw' + + async with self.connect(username='pw', password=get_password, + client_keys=None): + pass + + @asynctest + async def test_password_auth_awaitable(self): + """Test connecting with an awaitable for password authentication""" + + async def get_password(): + return 'pw' + + async with self.connect(username='pw', password=get_password(), + client_keys=None): + pass + @asynctest async def test_password_auth_disabled(self): """Test connecting with password authentication disabled""" @@ -1941,7 +2089,7 @@ async def test_kbdint_auth_callback_multi(self): pass @asynctest - async def test_kbdint_auth_callback_faliure(self): + async def test_kbdint_auth_callback_failure(self): """Test failure connecting with keyboard-interactive auth callback""" with self.assertRaises(asyncssh.PermissionDenied): diff --git a/tests/test_forward.py b/tests/test_forward.py index cba318a..e0b2bce 100644 --- a/tests/test_forward.py +++ b/tests/test_forward.py @@ -37,7 +37,7 @@ from asyncssh.socks import SOCKS4_OK_RESPONSE, SOCKS5_OK_RESPONSE_HDR from .server import Server, ServerTestCase -from .util import asynctest, echo, make_certificate +from .util import asynctest, echo, make_certificate, try_remove def _echo_non_async(stdin, stdout, stderr=None): @@ -183,6 +183,18 @@ async def server_requested(self, listen_host, listen_port): return listen_host != 'fail' +class _TCPAcceptHandlerServer(Server): + """Server for testing forwarding accept handler""" + + async def server_requested(self, listen_host, listen_port): + """Handle a request to create a new socket listener""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + return accept_handler + + class _UNIXConnectionServer(Server): """Server for testing direct and forwarded UNIX domain connections""" @@ -217,6 +229,25 @@ async def unix_server_requested(self, listen_path): return listen_path != 'fail' +class _UpstreamForwardingServer(Server): + """Server for testing forwarding between SSH connections""" + + def __init__(self, upstream_conn): + super().__init__() + + self._upstream_conn = upstream_conn + + def connection_requested(self, dest_host, dest_port, orig_host, orig_port): + """Handle a request to create a new connection""" + + return self._upstream_conn + + def unix_connection_requested(self, dest_path): + """Handle a request to create a new UNIX domain connection""" + + return self._upstream_conn + + class _CheckForwarding(ServerTestCase): """Utility functions for AsyncSSH forwarding unit tests""" @@ -258,6 +289,24 @@ async def _check_echo_block(self, reader, writer): self.assertEqual(b''.join(data), result) + async def _check_local_connection(self, listen_port, delay=None): + """Open a local connection and test if an input line is echoed back""" + + reader, writer = await asyncio.open_connection('127.0.0.1', + listen_port) + + await self._check_echo_line(reader, writer, delay=delay) + + async def _check_local_unix_connection(self, listen_path): + """Open a local connection and test if an input line is echoed back""" + + # pylint doesn't think open_unix_connection exists + # pylint: disable=no-member + reader, writer = await asyncio.open_unix_connection(listen_path) + # pylint: enable=no-member + + await self._check_echo_line(reader, writer) + class _TestTCPForwarding(_CheckForwarding): """Unit tests for AsyncSSH TCP connection forwarding""" @@ -266,8 +315,8 @@ class _TestTCPForwarding(_CheckForwarding): async def start_server(cls): """Start an SSH server which supports TCP connection forwarding""" - return (await cls.create_server( - _TCPConnectionServer, authorized_client_keys='authorized_keys')) + return await cls.create_server( + _TCPConnectionServer, authorized_client_keys='authorized_keys') async def _check_connection(self, conn, dest_host='', dest_port=7, **kwargs): @@ -278,14 +327,6 @@ async def _check_connection(self, conn, dest_host='', await self._check_echo_block(reader, writer) - async def _check_local_connection(self, listen_port, delay=None): - """Open a local connection and test if an input line is echoed back""" - - reader, writer = await asyncio.open_connection('127.0.0.1', - listen_port) - - await self._check_echo_line(reader, writer, delay=delay) - @asynctest async def test_ssh_create_tunnel(self): """Test creating a tunneled SSH connection""" @@ -310,8 +351,8 @@ async def test_ssh_connect_tunnel(self): async def test_ssh_connect_tunnel_string(self): """Test connecting a tunneled SSH connection via string""" - async with self.connect(tunnel='%s:%d' % (self._server_addr, - self._server_port)) as conn: + async with self.connect(tunnel=f'{self._server_addr}:' + f'{self._server_port}') as conn: await self._check_connection(conn) @asynctest @@ -319,9 +360,8 @@ async def test_ssh_connect_tunnel_string_failed(self): """Test failed connection on a tunneled SSH connection via string""" with self.assertRaises(asyncssh.ChannelOpenError): - await asyncssh.connect('\xff', - tunnel='%s:%d' % (self._server_addr, - self._server_port)) + await asyncssh.connect( + '\xff', tunnel=f'{self._server_addr}:{self._server_port}') @asynctest async def test_proxy_jump(self): @@ -339,6 +379,23 @@ async def test_proxy_jump(self): finally: os.remove('.ssh/config') + @asynctest + async def test_proxy_jump_multiple(self): + """Test connecting a tunnneled SSH connection using ProxyJump""" + + write_file('.ssh/config', 'Host target\n' + ' Hostname localhost\n' + f' Port {self._server_port}\n' + f' ProxyJump localhost:{self._server_port},' + f'localhost:{self._server_port}\n' + 'IdentityFile ckey\n', 'w') + + try: + async with self.connect(host='target', username='ckey'): + pass + finally: + os.remove('.ssh/config') + @asynctest async def test_proxy_jump_encrypted_key(self): """Test ProxyJump with encrypted client key""" @@ -400,8 +457,10 @@ async def test_ssh_listen_tunnel(self): async with self.connect() as conn: async with conn.listen_ssh(port=0, server_factory=Server, - server_host_keys=['skey']) as server2: - listen_port = server2.get_port() + server_host_keys=['skey']) as server: + listen_port = server.get_port() + + self.assertEqual(server.get_addresses(), [('', listen_port)]) async with asyncssh.connect('127.0.0.1', listen_port, known_hosts=(['skey.pub'], [], [])): @@ -411,10 +470,9 @@ async def test_ssh_listen_tunnel(self): async def test_ssh_listen_tunnel_string(self): """Test opening a tunneled SSH listener via string""" - async with self.listen(tunnel='ckey@%s:%d' % (self._server_addr, - self._server_port), - server_factory=Server, - server_host_keys=['skey']) as server: + async with self.listen( + tunnel=f'ckey@{self._server_addr}:{self._server_port}', + server_factory=Server, server_host_keys=['skey']) as server: listen_port = server.get_port() async with asyncssh.connect('127.0.0.1', listen_port, @@ -426,11 +484,9 @@ async def test_ssh_listen_tunnel_string_failed(self): """Test open failure on a tunneled SSH listener via string""" with self.assertRaises(asyncssh.ChannelListenError): - await asyncssh.listen('\xff', - tunnel='%s:%d' % (self._server_addr, - self._server_port), - server_factory=Server, - server_host_keys=['skey']) + await asyncssh.listen( + '\xff', tunnel=f'{self._server_addr}:{self._server_port}', + server_factory=Server, server_host_keys=['skey']) @asynctest async def test_ssh_listen_tunnel_default_port(self): @@ -584,6 +640,65 @@ async def test_forward_local_port(self): await self._check_local_connection(listener.get_port(), delay=0.1) + @asynctest + async def test_forward_local_port_accept_handler(self): + """Test forwarding of a local port with an accept handler""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + async with self.connect() as conn: + async with conn.forward_local_port('', 0, '', 7, + accept_handler) as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) + + @asynctest + async def test_forward_local_port_accept_handler_denial(self): + """Test forwarding of a local port with an accept handler denial""" + + async def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return False + + async with self.connect() as conn: + async with conn.forward_local_port('', 0, '', 7, + accept_handler) as listener: + listen_port = listener.get_port() + + reader, writer = await asyncio.open_connection('127.0.0.1', + listen_port) + + self.assertEqual((await reader.read()), b'') + + writer.close() + await maybe_wait_closed(writer) + + @unittest.skipIf(sys.platform == 'win32', + 'skip UNIX domain socket tests on Windows') + @asynctest + async def test_forward_local_path_to_port(self): + """Test forwarding of a local UNIX domain path to a remote TCP port""" + + async with self.connect() as conn: + async with conn.forward_local_path_to_port('local', '', 7): + await self._check_local_unix_connection('local') + + try_remove('local') + + @unittest.skipIf(sys.platform == 'win32', + 'skip UNIX domain socket tests on Windows') + @asynctest + async def test_forward_local_path_to_port_failure(self): + """Test failure forwarding a local UNIX domain path to a TCP port""" + + open('local', 'w').close() + + async with self.connect() as conn: + with self.assertRaises(OSError): + await conn.forward_local_path_to_port('local', '', 7) + + try_remove('local') + @asynctest async def test_forward_local_port_pause(self): """Test pause during forwarding of a local port""" @@ -642,6 +757,18 @@ async def test_forward_bind_error_ipv6(self): await conn.forward_local_port('', listener.get_port(), '', 7) + @unittest.skipIf(sys.platform == 'win32', + 'skip UNIX domain socket tests on Windows') + @asynctest + async def test_forward_port_to_path_bind_error(self): + """Test error binding a local port forwarding to remote path""" + + async with self.connect() as conn: + async with conn.forward_local_port('0.0.0.0', 0, '', 7) as listener: + with self.assertRaises(OSError): + await conn.forward_local_port_to_path( + '', listener.get_port(), '') + @asynctest async def test_forward_connect_error(self): """Test error connecting a local forwarding port""" @@ -688,6 +815,24 @@ async def test_forward_remote_port(self): server.close() await server.wait_closed() + @unittest.skipIf(sys.platform == 'win32', + 'skip UNIX domain socket tests on Windows') + @asynctest + async def test_forward_remote_port_to_path(self): + """Test forwarding of a remote port to a local UNIX domain socket""" + + server = await asyncio.start_unix_server(echo, 'local') + + async with self.connect() as conn: + async with conn.forward_remote_port_to_path( + '', 0, 'local') as listener: + await self._check_local_connection(listener.get_port()) + + server.close() + await server.wait_closed() + + try_remove('local') + @asynctest async def test_forward_remote_specific_port(self): """Test forwarding of a specific remote port""" @@ -750,6 +895,25 @@ async def test_cancel_forward_remote_port_invalid_unicode(self): self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) + @asynctest + async def test_upstream_forward_local_port(self): + """Test upstream forwarding of a local port""" + + def upstream_server(): + """Return a server capable of forwarding between SSH connections""" + + return _UpstreamForwardingServer(upstream_conn) + + async with self.connect() as upstream_conn: + upstream_listener = await self.create_server(upstream_server) + upstream_port = upstream_listener.get_port() + + async with self.connect('127.0.0.1', upstream_port) as conn: + async with conn.forward_local_port('', 0, '', 7) as listener: + await self._check_local_connection(listener.get_port()) + + upstream_listener.close() + @asynctest async def test_add_channel_after_close(self): """Test opening a connection after a close""" @@ -790,6 +954,33 @@ async def test_listener_close_on_conn_close(self): await listener.wait_closed() +class _TestTCPForwardingAcceptHandler(_CheckForwarding): + """Unit tests for TCP forwarding with accept handler""" + + @classmethod + async def start_server(cls): + """Start an SSH server which supports TCP connection forwarding""" + + return await cls.create_server( + _TCPAcceptHandlerServer, authorized_client_keys='authorized_keys') + + @asynctest + async def test_forward_remote_port_accept_handler(self): + """Test forwarding of a remote port with accept handler""" + + server = await asyncio.start_server(echo, None, 0, + family=socket.AF_INET) + server_port = server.sockets[0].getsockname()[1] + + async with self.connect() as conn: + async with conn.forward_remote_port( + '', 0, '127.0.0.1', server_port) as listener: + await self._check_local_connection(listener.get_port()) + + server.close() + await server.wait_closed() + + class _TestAsyncTCPForwarding(_TestTCPForwarding): """Unit tests for AsyncSSH TCP connection forwarding with async return""" @@ -810,8 +1001,8 @@ class _TestUNIXForwarding(_CheckForwarding): async def start_server(cls): """Start an SSH server which supports UNIX connection forwarding""" - return (await cls.create_server( - _UNIXConnectionServer, authorized_client_keys='authorized_keys')) + return await cls.create_server( + _UNIXConnectionServer, authorized_client_keys='authorized_keys') async def _check_unix_connection(self, conn, dest_path='/echo', **kwargs): """Open a UNIX connection and test if an input line is echoed back""" @@ -822,16 +1013,6 @@ async def _check_unix_connection(self, conn, dest_path='/echo', **kwargs): await self._check_echo_line(reader, writer, encoded=True) - async def _check_local_unix_connection(self, listen_path): - """Open a local connection and test if an input line is echoed back""" - - # pylint doesn't think open_unix_connection exists - # pylint: disable=no-member - reader, writer = await asyncio.open_unix_connection(listen_path) - # pylint: enable=no-member - - await self._check_echo_line(reader, writer) - @asynctest async def test_unix_connection(self): """Test opening a remote UNIX connection""" @@ -891,7 +1072,7 @@ async def test_unix_server(self): await listener.wait_closed() listener.close() - os.remove('echo') + try_remove('echo') @asynctest async def test_unix_server_open(self): @@ -924,7 +1105,7 @@ async def test_unix_server_non_async(self): async with conn.start_unix_server(_unix_listener_non_async, path): await self._check_local_unix_connection('echo') - os.remove('echo') + try_remove('echo') @asynctest async def test_unix_server_failure(self): @@ -942,7 +1123,65 @@ async def test_forward_local_path(self): async with conn.forward_local_path('local', '/echo'): await self._check_local_unix_connection('local') - os.remove('local') + try_remove('local') + + @asynctest + async def test_forward_local_port_to_path_accept_handler(self): + """Test forwarding of port to UNIX path with accept handler""" + + def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return True + + async with self.connect() as conn: + async with conn.forward_local_port_to_path( + '', 0, '/echo', accept_handler) as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) + + @asynctest + async def test_forward_local_port_to_path_accept_handler_denial(self): + """Test forwarding of port to UNIX path with accept handler denial""" + + async def accept_handler(_orig_host: str, _orig_port: int) -> bool: + return False + + async with self.connect() as conn: + async with conn.forward_local_port_to_path( + '', 0, '/echo', accept_handler) as listener: + listen_port = listener.get_port() + + reader, writer = await asyncio.open_connection('127.0.0.1', + listen_port) + + self.assertEqual((await reader.read()), b'') + + writer.close() + await maybe_wait_closed(writer) + + @asynctest + async def test_forward_local_port_to_path(self): + """Test forwarding of a local port to a remote UNIX domain socket""" + + async with self.connect() as conn: + async with conn.forward_local_port_to_path('', 0, + '/echo') as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) + + @asynctest + async def test_forward_specific_local_port_to_path(self): + """Test forwarding of a specific local port to a UNIX domain socket""" + + sock = socket.socket() + sock.bind(('', 0)) + listen_port = sock.getsockname()[1] + sock.close() + + async with self.connect() as conn: + async with conn.forward_local_port_to_path( + '', listen_port, '/echo') as listener: + await self._check_local_connection(listener.get_port(), + delay=0.1) @asynctest async def test_forward_remote_path(self): @@ -962,8 +1201,28 @@ async def test_forward_remote_path(self): server.close() await server.wait_closed() - os.remove('echo') - os.remove('local') + try_remove('echo') + try_remove('local') + + @asynctest + async def test_forward_remote_path_to_port(self): + """Test forwarding of a remote UNIX domain path to a local TCP port""" + + server = await asyncio.start_server(echo, None, 0, + family=socket.AF_INET) + server_port = server.sockets[0].getsockname()[1] + + path = os.path.abspath('echo') + + async with self.connect() as conn: + async with conn.forward_remote_path_to_port( + path, '127.0.0.1', server_port): + await self._check_local_unix_connection('echo') + + server.close() + await server.wait_closed() + + try_remove('echo') @asynctest async def test_forward_remote_path_failure(self): @@ -977,7 +1236,7 @@ async def test_forward_remote_path_failure(self): with self.assertRaises(asyncssh.ChannelListenError): await conn.forward_remote_path(path, 'local') - os.remove('echo') + try_remove('echo') @asynctest async def test_forward_remote_path_not_permitted(self): @@ -1012,6 +1271,25 @@ async def test_cancel_forward_remote_path_invalid_unicode(self): self.assertEqual(pkttype, asyncssh.MSG_REQUEST_FAILURE) + @asynctest + async def test_upstream_forward_local_path(self): + """Test upstream forwarding of a local path""" + + def upstream_server(): + """Return a server capable of forwarding between SSH connections""" + + return _UpstreamForwardingServer(upstream_conn) + + async with self.connect() as upstream_conn: + upstream_listener = await self.create_server(upstream_server) + upstream_port = upstream_listener.get_port() + + async with self.connect('127.0.0.1', upstream_port) as conn: + async with conn.forward_local_path('local', '/echo'): + await self._check_local_unix_connection('local') + + upstream_listener.close() + class _TestAsyncUNIXForwarding(_TestUNIXForwarding): """Unit tests for AsyncSSH UNIX connection forwarding with async return""" @@ -1032,8 +1310,8 @@ class _TestSOCKSForwarding(_CheckForwarding): async def start_server(cls): """Start an SSH server which supports TCP connection forwarding""" - return (await cls.create_server( - _TCPConnectionServer, authorized_client_keys='authorized_keys')) + return await cls.create_server( + _TCPConnectionServer, authorized_client_keys='authorized_keys') async def _check_early_error(self, reader, writer, data): """Check errors in the initial SOCKS message""" diff --git a/tests/test_kex.py b/tests/test_kex.py index 208298c..a03fd0e 100644 --- a/tests/test_kex.py +++ b/tests/test_kex.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by Ron Frederick and others. +# Copyright (c) 2015-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -21,6 +21,7 @@ """Unit tests for key exchange""" import asyncio +import inspect import unittest from hashlib import sha1 @@ -28,13 +29,14 @@ import asyncssh from asyncssh.crypto import curve25519_available, curve448_available -from asyncssh.crypto import Curve25519DH, Curve448DH, ECDH +from asyncssh.crypto import sntrup_available +from asyncssh.crypto import Curve25519DH, Curve448DH, ECDH, PQDH from asyncssh.kex_dh import MSG_KEXDH_INIT, MSG_KEXDH_REPLY from asyncssh.kex_dh import MSG_KEX_DH_GEX_REQUEST, MSG_KEX_DH_GEX_GROUP from asyncssh.kex_dh import MSG_KEX_DH_GEX_INIT, MSG_KEX_DH_GEX_REPLY, _KexDHGex from asyncssh.kex_dh import MSG_KEX_ECDH_INIT, MSG_KEX_ECDH_REPLY -from asyncssh.kex_dh import MSG_KEXGSS_INIT, MSG_KEXGSS_COMPLETE -from asyncssh.kex_dh import MSG_KEXGSS_ERROR +from asyncssh.kex_dh import MSG_KEXGSS_INIT, MSG_KEXGSS_HOSTKEY +from asyncssh.kex_dh import MSG_KEXGSS_COMPLETE, MSG_KEXGSS_ERROR from asyncssh.kex_rsa import MSG_KEXRSA_PUBKEY, MSG_KEXRSA_SECRET from asyncssh.kex_rsa import MSG_KEXRSA_DONE from asyncssh.gss import GSSClient, GSSServer @@ -42,25 +44,26 @@ from asyncssh.packet import SSHPacket, Boolean, Byte, MPInt, String from asyncssh.public_key import decode_ssh_public_key -from .util import asynctest, gss_available, patch_gss +from .util import asynctest, get_test_key, gss_available, patch_gss from .util import AsyncTestCase, ConnectionStub class _KexConnectionStub(ConnectionStub): """Connection stub class to test key exchange""" - def __init__(self, alg, gss, peer, server=False): + def __init__(self, alg, gss, duplicate=0, peer=None, server=False): super().__init__(peer, server) self._gss = gss self._key_waiter = asyncio.Future() + self._duplicate = duplicate self._kex = get_kex(self, alg) - def start(self): + async def start(self): """Start key exchange""" - self._kex.start() + await self._kex.start() def connection_lost(self, exc): """Handle the closing of a connection""" @@ -70,12 +73,15 @@ def connection_lost(self, exc): def enable_gss_kex_auth(self): """Ignore request to enable GSS key exchange authentication""" - def process_packet(self, data): + async def process_packet(self, data): """Process an incoming packet""" packet = SSHPacket(data) pkttype = packet.get_byte() - self._kex.process_packet(pkttype, None, packet) + result = self._kex.process_packet(pkttype, None, packet) + + if inspect.isawaitable(result): + await result def get_hash_prefix(self): """Return the bytes used in calculating unique connection hashes""" @@ -99,89 +105,101 @@ def get_gss_context(self): return self._gss - def simulate_dh_init(self, e): + def send_packet(self, pkttype, *args, **kwargs): + """Duplicate sending packets of a specific type""" + + super().send_packet(pkttype, *args) + + if pkttype == self._duplicate: + super().send_packet(pkttype, *args, **kwargs) + + async def simulate_dh_init(self, e): """Simulate receiving a DH init packet""" - self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) + await self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e)) - def simulate_dh_reply(self, host_key_data, f, sig): + async def simulate_dh_reply(self, host_key_data, f, sig): """Simulate receiving a DH reply packet""" - self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), - String(host_key_data), - MPInt(f), String(sig)))) + await self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY), + String(host_key_data), + MPInt(f), String(sig)))) - def simulate_dh_gex_group(self, p, g): + async def simulate_dh_gex_group(self, p, g): """Simulate receiving a DH GEX group packet""" - self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + MPInt(p) + MPInt(g)) + await self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + + MPInt(p) + MPInt(g)) - def simulate_dh_gex_init(self, e): + async def simulate_dh_gex_init(self, e): """Simulate receiving a DH GEX init packet""" - self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) + await self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e)) - def simulate_dh_gex_reply(self, host_key_data, f, sig): + async def simulate_dh_gex_reply(self, host_key_data, f, sig): """Simulate receiving a DH GEX reply packet""" - self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), - String(host_key_data), + await self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY), + String(host_key_data), MPInt(f), String(sig)))) - def simulate_gss_complete(self, f, sig): + async def simulate_gss_complete(self, f, sig): """Simulate receiving a GSS complete packet""" - self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), MPInt(f), - String(sig), Boolean(False)))) + await self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), + MPInt(f), String(sig), + Boolean(False)))) - def simulate_ecdh_init(self, client_pub): + async def simulate_ecdh_init(self, client_pub): """Simulate receiving an ECDH init packet""" - self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) + await self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub)) - def simulate_ecdh_reply(self, host_key_data, server_pub, sig): + async def simulate_ecdh_reply(self, host_key_data, server_pub, sig): """Simulate receiving ab ECDH reply packet""" - self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), - String(host_key_data), - String(server_pub), String(sig)))) + await self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY), + String(host_key_data), + String(server_pub), String(sig)))) - def simulate_rsa_pubkey(self, host_key_data, trans_key_data): + async def simulate_rsa_pubkey(self, host_key_data, trans_key_data): """Simulate receiving an RSA pubkey packet""" - self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + String(host_key_data) + - String(trans_key_data)) + await self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + + String(host_key_data) + + String(trans_key_data)) - def simulate_rsa_secret(self, encrypted_k): + async def simulate_rsa_secret(self, encrypted_k): """Simulate receiving an RSA secret packet""" - self.process_packet(Byte(MSG_KEXRSA_SECRET) + String(encrypted_k)) + await self.process_packet(Byte(MSG_KEXRSA_SECRET) + + String(encrypted_k)) - def simulate_rsa_done(self, sig): + async def simulate_rsa_done(self, sig): """Simulate receiving an RSA done packet""" - self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig)) + await self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig)) class _KexClientStub(_KexConnectionStub): """Stub class for client connection""" @classmethod - def make_pair(cls, alg, gss_host=None): + def make_pair(cls, alg, gss_host=None, duplicate=0): """Make a client and server connection pair to test key exchange""" - client_conn = cls(alg, gss_host) + client_conn = cls(alg, gss_host, duplicate) return client_conn, client_conn.get_peer() - def __init__(self, alg, gss_host): - server_conn = _KexServerStub(alg, gss_host, self) + def __init__(self, alg, gss_host, duplicate): + server_conn = _KexServerStub(alg, gss_host, duplicate, peer=self) if gss_host: - gss = GSSClient(gss_host, 'delegate' in gss_host) + gss = GSSClient(gss_host, None, 'delegate' in gss_host) else: gss = None - super().__init__(alg, gss, server_conn) + super().__init__(alg, gss, duplicate, peer=server_conn) def connection_lost(self, exc): """Handle the closing of a connection""" @@ -202,14 +220,14 @@ def validate_server_host_key(self, host_key_data): class _KexServerStub(_KexConnectionStub): """Stub class for server connection""" - def __init__(self, alg, gss_host, peer): - gss = GSSServer(gss_host) if gss_host else None - super().__init__(alg, gss, peer, True) + def __init__(self, alg, gss_host, duplicate, peer): + gss = GSSServer(gss_host, None) if gss_host else None + super().__init__(alg, gss, duplicate, peer, True) if gss_host and 'no_host_key' in gss_host: self._server_host_key = None else: - priv_key = asyncssh.generate_private_key('ssh-rsa') + priv_key = get_test_key('ecdsa-sha2-nistp256') self._server_host_key = asyncssh.load_keypairs(priv_key)[0] def connection_lost(self, exc): @@ -236,8 +254,8 @@ async def _check_kex(self, alg, gss_host=None): client_conn, server_conn = _KexClientStub.make_pair(alg, gss_host) try: - client_conn.start() - server_conn.start() + await client_conn.start() + await server_conn.start() self.assertEqual((await client_conn.get_key()), (await server_conn.get_key())) @@ -313,25 +331,27 @@ async def test_dh_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEXDH_INIT)) + await client_conn.process_packet(Byte(MSG_KEXDH_INIT)) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) + await server_conn.process_packet(Byte(MSG_KEXDH_REPLY)) with self.subTest('Invalid e value'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_init(0) + await server_conn.simulate_dh_init(0) with self.subTest('Invalid f value'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.start() - client_conn.simulate_dh_reply(host_key.public_data, 0, b'') + await client_conn.start() + await client_conn.simulate_dh_reply(host_key.public_data, + 0, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): - client_conn.start() - client_conn.simulate_dh_reply(host_key.public_data, 1, b'') + await client_conn.start() + await client_conn.simulate_dh_reply(host_key.public_data, + 2, b'') client_conn.close() server_conn.close() @@ -345,31 +365,46 @@ async def test_dh_gex_errors(self): with self.subTest('Request sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) + await client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST)) with self.subTest('Group sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_group(1, 2) + await server_conn.simulate_dh_gex_group(1, 2) with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_dh_gex_init(1) + await client_conn.simulate_dh_gex_init(1) with self.subTest('Init sent before group'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_init(1) + await server_conn.simulate_dh_gex_init(1) with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_dh_gex_reply(b'', 1, b'') + await server_conn.simulate_dh_gex_reply(b'', 1, b'') with self.subTest('Reply sent before group'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_dh_gex_reply(b'', 1, b'') + await client_conn.simulate_dh_gex_reply(b'', 1, b'') client_conn.close() server_conn.close() + @asynctest + async def test_dh_gex_multiple_messages(self): + """Unit test duplicate messages in DH group exchange""" + + for pkttype in (MSG_KEX_DH_GEX_REQUEST, MSG_KEX_DH_GEX_GROUP): + client_conn, server_conn = _KexClientStub.make_pair( + b'diffie-hellman-group-exchange-sha1', duplicate=pkttype) + + with self.assertRaises(asyncssh.ProtocolError): + await client_conn.start() + await client_conn.get_key() + + client_conn.close() + server_conn.close() + @unittest.skipUnless(gss_available, 'GSS not available') @asynctest async def test_gss_errors(self): @@ -380,26 +415,36 @@ async def test_gss_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) + await client_conn.process_packet(Byte(MSG_KEXGSS_INIT)) + + with self.subTest('Host key sent to server'): + with self.assertRaises(asyncssh.ProtocolError): + await server_conn.process_packet(Byte(MSG_KEXGSS_HOSTKEY)) + + with self.subTest('Host key sent twice to client'): + with self.assertRaises(asyncssh.ProtocolError): + await client_conn.process_packet(Byte(MSG_KEXGSS_HOSTKEY)) + await client_conn.process_packet(Byte(MSG_KEXGSS_HOSTKEY)) with self.subTest('Complete sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) + await server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE)) with self.subTest('Exchange failed to complete'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_gss_complete(1, b'succeed') + await client_conn.simulate_gss_complete(1, b'succeed') with self.subTest('Error sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) + await server_conn.process_packet(Byte(MSG_KEXGSS_ERROR)) client_conn.close() server_conn.close() with self.subTest('Signature verification failure'): with self.assertRaises(asyncssh.KeyExchangeFailed): - await self._check_kex(b'gss-group1-sha1-mech', '0,fail') + await self._check_kex(b'gss-group1-sha1-mech', + '0,verify_error') with self.subTest('Empty token in init'): with self.assertRaises(asyncssh.ProtocolError): @@ -444,31 +489,32 @@ async def test_ecdh_errors(self): with self.subTest('Init sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_ecdh_init(b'') + await client_conn.simulate_ecdh_init(b'') with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Reply sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_reply(b'', b'', b'') + await server_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server host key'): with self.assertRaises(asyncssh.KeyImportError): - client_conn.simulate_ecdh_reply(b'', b'', b'') + await client_conn.simulate_ecdh_reply(b'', b'', b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = ECDH(b'nistp256').get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') client_conn.close() server_conn.close() @@ -483,26 +529,27 @@ async def test_curve25519dh_errors(self): with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid peer public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() server_pub = b'\x01' + 31*b'\x00' - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve25519DH().get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') client_conn.close() server_conn.close() @@ -517,26 +564,55 @@ async def test_curve448dh_errors(self): with self.subTest('Invalid client public key'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_ecdh_init(b'') + await server_conn.simulate_ecdh_init(b'') with self.subTest('Invalid server public key'): with self.assertRaises(asyncssh.ProtocolError): host_key = server_conn.get_server_host_key() - client_conn.simulate_ecdh_reply(host_key.public_data, b'', b'') - - with self.subTest('Invalid peer public key'): - with self.assertRaises(asyncssh.ProtocolError): - host_key = server_conn.get_server_host_key() - server_pub = b'\x01' + 55*b'\x00' - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() server_pub = Curve448DH().get_public() - client_conn.simulate_ecdh_reply(host_key.public_data, - server_pub, b'') + await client_conn.simulate_ecdh_reply(host_key.public_data, + server_pub, b'') + + client_conn.close() + server_conn.close() + + @unittest.skipUnless(sntrup_available, 'SNTRUP761 not available') + @asynctest + async def test_sntrup761dh_errors(self): + """Unit test error conditions in SNTRUP761 key exchange""" + + pqdh = PQDH(b'sntrup761') + + client_conn, server_conn = \ + _KexClientStub.make_pair(b'sntrup761x25519-sha512@openssh.com') + + with self.subTest('Invalid client SNTRUP761 public key'): + with self.assertRaises(asyncssh.ProtocolError): + await server_conn.simulate_ecdh_init(b'') + + with self.subTest('Invalid client Curve25519 public key'): + with self.assertRaises(asyncssh.ProtocolError): + pub = pqdh.pubkey_bytes * b'\0' + await server_conn.simulate_ecdh_init(pub) + + with self.subTest('Invalid server SNTRUP761 public key'): + with self.assertRaises(asyncssh.ProtocolError): + host_key = server_conn.get_server_host_key() + await client_conn.simulate_ecdh_reply(host_key.public_data, + b'', b'') + + with self.subTest('Invalid server Curve25519 public key'): + with self.assertRaises(asyncssh.ProtocolError): + host_key = server_conn.get_server_host_key() + ciphertext = pqdh.ciphertext_bytes * b'\0' + await client_conn.simulate_ecdh_reply(host_key.public_data, + ciphertext, b'') client_conn.close() server_conn.close() @@ -550,32 +626,32 @@ async def test_rsa_errors(self): with self.subTest('Pubkey sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_rsa_pubkey(b'', b'') + await server_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Secret sent to client'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_rsa_secret(b'') + await client_conn.simulate_rsa_secret(b'') with self.subTest('Done sent to server'): with self.assertRaises(asyncssh.ProtocolError): - server_conn.simulate_rsa_done(b'') + await server_conn.simulate_rsa_done(b'') with self.subTest('Invalid transient public key'): with self.assertRaises(asyncssh.ProtocolError): - client_conn.simulate_rsa_pubkey(b'', b'') + await client_conn.simulate_rsa_pubkey(b'', b'') with self.subTest('Invalid encrypted secret'): with self.assertRaises(asyncssh.KeyExchangeFailed): - server_conn.start() - server_conn.simulate_rsa_secret(b'') + await server_conn.start() + await server_conn.simulate_rsa_secret(b'') with self.subTest('Invalid signature'): with self.assertRaises(asyncssh.KeyExchangeFailed): host_key = server_conn.get_server_host_key() - trans_key = asyncssh.generate_private_key('ssh-rsa', 2048) - client_conn.simulate_rsa_pubkey(host_key.public_data, - trans_key.public_data) - client_conn.simulate_rsa_done(b'') + trans_key = get_test_key('ssh-rsa', 2048) + await client_conn.simulate_rsa_pubkey(host_key.public_data, + trans_key.public_data) + await client_conn.simulate_rsa_done(b'') client_conn.close() server_conn.close() diff --git a/tests/test_known_hosts.py b/tests/test_known_hosts.py index 621f2ca..94ba7a7 100644 --- a/tests/test_known_hosts.py +++ b/tests/test_known_hosts.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2018 by Ron Frederick and others. +# Copyright (c) 2015-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -27,7 +27,7 @@ import asyncssh -from .util import TempDirTestCase, x509_available +from .util import TempDirTestCase, get_test_key, x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509NamePattern @@ -58,16 +58,16 @@ def setUpClass(cls): for keylist, imported_keylist in zip(cls.keylists[:3], cls.imported_keylists[:3]): - for _ in range(3): - key = asyncssh.generate_private_key('ssh-rsa') + for i in range(3): + key = get_test_key('ssh-rsa', i) keylist.append(key.export_public_key().decode('ascii')) imported_keylist.append(key.convert_to_public()) if x509_available: # pragma: no branch for keylist, imported_keylist in zip(cls.keylists[3:5], cls.imported_keylists[3:5]): - for _ in range(2): - key = asyncssh.generate_private_key('ssh-rsa') + for i in range(3, 5): + key = get_test_key('ssh-rsa', i) cert = key.generate_x509_user_certificate(key, 'OU=user', 'OU=user') keylist.append( @@ -108,7 +108,7 @@ def call_match(host, addr, port): for prefix, patlist, keys in zip(prefixes, patlists, self.keylists): for pattern, key in zip(patlist, keys): - known_hosts += '%s%s %s' % (prefix, pattern, key) + known_hosts += f'{prefix}{pattern} {key}' if from_file: with open('known_hosts', 'w') as f: @@ -234,12 +234,12 @@ def test_missing_key_with_tag(self): self.check_match(b'@cert-authority xxx\n') def test_invalid_key(self): - """Test for line with invaid key""" + """Test for line with invalid key""" self.check_match(b'xxx yyy\n', ([], [], [], [], [], [], [])) def test_invalid_marker(self): - """Test for line with invaid marker""" + """Test for line with invalid marker""" with self.assertRaises(ValueError): self.check_match(b'@xxx yyy zzz\n') diff --git a/tests/test_logging.py b/tests/test_logging.py index 9317ef2..459d632 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -1,4 +1,4 @@ -# Copyright (c) 2017-2019 by Ron Frederick and others. +# Copyright (c) 2017-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -138,7 +138,7 @@ async def test_channel_log(self): self.assertEqual(len(log.records), 1) self.assertRegex(log.records[0].msg, - r'\[conn=\d+, chan=%s\] Test' % i) + rf'\[conn=\d+, chan={i}\] Test') @asynctest async def test_stream_log(self): diff --git a/tests/test_pkcs11.py b/tests/test_pkcs11.py index 4044506..bacf05c 100644 --- a/tests/test_pkcs11.py +++ b/tests/test_pkcs11.py @@ -165,6 +165,10 @@ async def test_pkcs11_load_keys(self): for sig_alg in key.sig_algorithms: sig_alg = sig_alg.decode('ascii') + # Disable unit tests that involve SHA-1 hashes + if sig_alg in ('ssh-rsa', 'x509v3-ssh-rsa'): + continue + with self.subTest(key=key.get_comment(), sig_alg=sig_alg): async with self.connect( username='ckey', pkcs11_provider='xxx', diff --git a/tests/test_process.py b/tests/test_process.py index bd1b3aa..326feb2 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1,4 +1,4 @@ -# Copyright (c) 2016-2021 by Ron Frederick and others. +# Copyright (c) 2016-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -34,12 +34,18 @@ from .server import ServerTestCase from .util import asynctest, echo +if sys.platform != 'win32': # pragma: no branch + import fcntl + import struct + import termios + try: import aiofiles _aiofiles_available = True except ImportError: # pragma: no cover _aiofiles_available = False + async def _handle_client(process): """Handle a new client request""" @@ -64,6 +70,16 @@ async def _handle_client(process): elif action == 'env': process.channel.set_encoding('utf-8') process.stdout.write(process.env.get('TEST', '')) + elif action.startswith('redirect '): + _, addr, port, action = action.split(None, 3) + + async with asyncssh.connect(addr, int(port)) as conn: + upstream_process = await conn.create_process( + command=action, encoding=None, term_type=process.term_type, + stdin=process.stdin, stdout=process.stdout) + + result = await upstream_process.wait() + process.exit_with_signal(*result.exit_signal) elif action == 'redirect_stdin': await process.redirect_stdin(process.stdout) await process.stdout.drain() @@ -89,7 +105,24 @@ async def _handle_client(process): await process.stdin.readline() except asyncssh.TerminalSizeChanged as exc: process.exit_with_signal('ABRT', False, - '%sx%s' % (exc.width, exc.height)) + f'{exc.width}x{exc.height}') + elif action == 'term_size_tty': + master, slave = os.openpty() + await process.redirect_stdin(master, recv_eof=False) + process.stdout.write(b'\n') + + await process.stdin.readline() + size = fcntl.ioctl(slave, termios.TIOCGWINSZ, 8*b'\0') + height, width, _, _ = struct.unpack('hhhh', size) + process.stdout.write(f'{width}x{height}'.encode()) + os.close(slave) + elif action == 'term_size_nontty': + rpipe, wpipe = os.pipe() + await process.redirect_stdin(wpipe) + process.stdout.write(b'\n') + + await process.stdin.readline() + os.close(rpipe) elif action == 'timeout': process.channel.set_encoding('utf-8') process.stdout.write('Sleeping') @@ -427,7 +460,7 @@ async def test_ignoring_invalid_unicode(self): async def test_incomplete_unicode(self): """Test incomplete Unicode data""" - data = '\u2000'.encode('utf-8')[:2] + data = '\u2000'.encode()[:2] with open('stdin', 'wb') as file: file.write(data) @@ -553,7 +586,7 @@ async def test_stdin_open_file(self): with open('stdin', 'w') as file: file.write(data) - file = open('stdin', 'r') + file = open('stdin') async with self.connect() as conn: result = await conn.run('echo', stdin=file) @@ -626,6 +659,77 @@ async def test_stdin_process(self): self.assertEqual(result.stdout, data) self.assertEqual(result.stderr, data) + @asynctest + async def test_forward_terminal_size(self): + """Test forwarding a terminal size change""" + + async with self.connect() as conn: + cmd = f'redirect {self._server_addr} {self._server_port} term_size' + process = await conn.create_process(cmd, term_type='ansi') + process.change_terminal_size(80, 24) + result = await process.wait() + + self.assertEqual(result.exit_signal[2], '80x24') + + @unittest.skipIf(sys.platform == 'win32', + 'skip TTY terminal size tests on Windows') + @asynctest + async def test_forward_terminal_size_tty(self): + """Test forwarding a terminal size change to a remote tty""" + + async with self.connect() as conn: + process = await conn.create_process('term_size_tty', + term_type='ansi') + await process.stdout.readline() + process.change_terminal_size(80, 24) + process.stdin.write_eof() + result = await process.wait() + + self.assertEqual(result.stdout, '80x24') + + @unittest.skipIf(sys.platform == 'win32', + 'skip TTY terminal size tests on Windows') + @asynctest + async def test_forward_terminal_size_nontty(self): + """Test forwarding a terminal size change to a remote non-tty""" + + async with self.connect() as conn: + process = await conn.create_process('term_size_nontty', + term_type='ansi') + await process.stdout.readline() + process.change_terminal_size(80, 24) + process.stdin.write_eof() + result = await process.wait() + + self.assertEqual(result.stdout, '') + + @asynctest + async def test_forward_break(self): + """Test forwarding a break""" + + async with self.connect() as conn: + cmd = f'redirect {self._server_addr} {self._server_port} break' + process = await conn.create_process(cmd) + process.send_break(1000) + result = await process.wait() + + self.assertEqual(result.exit_signal[2], '1000') + + @asynctest + async def test_forward_signal(self): + """Test forwarding a signal""" + + async with self.connect() as conn: + cmd = f'redirect {self._server_addr} {self._server_port} echo' + process = await conn.create_process(cmd) + process.stdin.write('\n') + await process.stdout.readline() + process.send_signal('INT') + result = await process.wait() + + self.assertEqual(result.exit_signal[0], 'INT') + self.assertEqual(result.returncode, -SIGINT) + @unittest.skipIf(sys.platform == 'win32', 'skip asyncio.subprocess tests on Windows') @asynctest @@ -669,7 +773,7 @@ async def test_stdout_file(self): async with self.connect() as conn: result = await conn.run('echo', input=data, stdout='stdout') - with open('stdout', 'r') as file: + with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) @@ -702,7 +806,7 @@ async def test_stdout_pathlib(self): async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=Path('stdout')) - with open('stdout', 'r') as file: + with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) @@ -720,13 +824,29 @@ async def test_stdout_open_file(self): async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) - with open('stdout', 'r') as file: + with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) + @asynctest + async def test_stdout_open_file_keep_open(self): + """Test with stdout redirected to an open file which remains open""" + + data = str(id(self)) + + with open('stdout', 'w') as file: + async with self.connect() as conn: + await conn.run('echo', input=data, stdout=file, recv_eof=False) + await conn.run('echo', input=data, stdout=file, recv_eof=False) + + with open('stdout') as file: + stdout_data = file.read() + + self.assertEqual(stdout_data, 2*data) + @asynctest async def test_stdout_open_binary_file(self): """Test with stdout redirected to an open binary file""" @@ -831,15 +951,32 @@ async def test_stdout_stream(self): 'cat', stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE) - proc1 = await conn.create_process(stdout=proc2.stdin, - stderr=asyncssh.DEVNULL) + async with conn.create_process(input=data, stdout=proc2.stdin): + stdout_data = await proc2.stdout.read() - proc1.stdin.write(data) - proc1.stdin.write_eof() + self.assertEqual(stdout_data, data.encode('ascii')) + + @unittest.skipIf(sys.platform == 'win32', + 'skip asyncio.subprocess tests on Windows') + @asynctest + async def test_stdout_stream_keep_open(self): + """Test with stdout redirected to asyncio stream which remains open""" - stdout_data, _ = await proc2.communicate() + data = str(id(self)) - self.assertEqual(stdout_data, data.encode('ascii')) + async with self.connect() as conn: + proc2 = await asyncio.create_subprocess_shell( + 'cat', stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE) + + await conn.run('echo', input=data, stdout=proc2.stdin, + stderr=asyncssh.DEVNULL, recv_eof=False) + await conn.run('echo', input=data, stdout=proc2.stdin, + stderr=asyncssh.DEVNULL) + + stdout_data = await proc2.stdout.read() + + self.assertEqual(stdout_data, 2*data.encode('ascii')) @asynctest async def test_change_stdout(self): @@ -858,7 +995,7 @@ async def test_change_stdout(self): result = await process.wait() - with open('stdout', 'r') as file: + with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, 'xxx') @@ -1104,18 +1241,34 @@ async def test_stdout_aiofile(self): data = str(id(self)) - file = open('stdout', 'w') + file = await aiofiles.open('stdout', 'w') async with self.connect() as conn: result = await conn.run('echo', input=data, stdout=file) - with open('stdout', 'r') as file: + with open('stdout') as file: stdout_data = file.read() self.assertEqual(stdout_data, data) self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) + @asynctest + async def test_stdout_aiofile_keep_open(self): + """Test with stdout redirected to an aiofile which remains open""" + + data = str(id(self)) + + async with aiofiles.open('stdout', 'w') as file: + async with self.connect() as conn: + await conn.run('echo', input=data, stdout=file, recv_eof=False) + await conn.run('echo', input=data, stdout=file, recv_eof=False) + + with open('stdout') as file: + stdout_data = file.read() + + self.assertEqual(stdout_data, 2*data) + @asynctest async def test_stdout_binary_aiofile(self): """Test with stdout redirected to an aiofile in binary mode""" @@ -1152,6 +1305,20 @@ async def test_pause_async_file_reader(self): self.assertEqual(result.stdout, data) + @asynctest + async def test_pause_async_file_writer(self): + """Test pausing and resuming writing to an aiofile""" + + data = 4*1024*1024*'*' + + async with aiofiles.open('stdout', 'w') as file: + async with self.connect() as conn: + await conn.run('delay', input=data, stdout=file, + stderr=asyncssh.DEVNULL) + + with open('stdout') as file: + self.assertEqual(file.read(), data) + @unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows') class _TestProcessPipes(_TestProcess): @@ -1229,6 +1396,28 @@ async def test_stdout_pipe(self): self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) + @asynctest + async def test_stdout_pipe_keep_open(self): + """Test with stdout redirected to a pipe which remains open""" + + data = str(id(self)) + + rpipe, wpipe = os.pipe() + + os.write(wpipe, data.encode()) + + async with self.connect() as conn: + await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) + await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) + + os.write(wpipe, data.encode()) + os.close(wpipe) + + stdout_data = os.read(rpipe, 1024) + os.close(rpipe) + + self.assertEqual(stdout_data.decode(), 4*data) + @asynctest async def test_stdout_text_pipe(self): """Test with stdout redirected to a pipe in text mode""" @@ -1250,6 +1439,31 @@ async def test_stdout_text_pipe(self): self.assertEqual(result.stdout, '') self.assertEqual(result.stderr, data) + @asynctest + async def test_stdout_text_pipe_keep_open(self): + """Test with stdout to a pipe in text mode which remains open""" + + data = str(id(self)) + + rpipe, wpipe = os.pipe() + + rpipe = os.fdopen(rpipe, 'r') + wpipe = os.fdopen(wpipe, 'w') + + wpipe.write(data) + + async with self.connect() as conn: + await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) + await conn.run('echo', input=data, stdout=wpipe, recv_eof=False) + + wpipe.write(data) + wpipe.close() + + stdout_data = rpipe.read(1024) + rpipe.close() + + self.assertEqual(stdout_data, 4*data) + @asynctest async def test_stdout_binary_pipe(self): """Test with stdout redirected to a pipe in binary mode""" @@ -1333,50 +1547,55 @@ async def test_stdout_socketpair(self): self.assertEqual(result.stderr, data) @asynctest - async def test_pause_socketpair_reader(self): - """Test pausing and resuming reading from a socketpair""" + async def test_pause_socketpair_pipes(self): + """Test pausing and resuming reading from and writing to pipes""" - data = 4*1024*1024*'*' + data = 4*1024*1024*b'*' sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - _, writer = await asyncio.open_unix_connection(sock=sock1) - writer.write(data.encode()) - writer.close() + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - async with self.connect() as conn: - result = await conn.run('delay', stdin=sock2, - stderr=asyncssh.DEVNULL) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock4) - self.assertEqual(result.stdout, data) - - @asynctest - async def test_pause_socketpair_writer(self): - """Test pausing and resuming writing to a socketpair""" + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=sock2, stdout=sock3, + stderr=asyncssh.DEVNULL) - data = 4*1024*1024*'*' + self.assertEqual((await reader2.read()), data) + await process.wait() - rsock1, wsock1 = socket.socketpair() - rsock2, wsock2 = socket.socketpair() + writer2.close() - reader1, writer1 = await asyncio.open_unix_connection(sock=rsock1) - reader2, writer2 = await asyncio.open_unix_connection(sock=rsock2) + @asynctest + async def test_pause_socketpair_streams(self): + """Test pausing and resuming reading from and writing to streams""" - async with self.connect() as conn: - process = await conn.create_process(input=data) + data = 4*1024*1024*b'*' - await asyncio.sleep(1) + sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - await process.redirect_stdout(wsock1) - await process.redirect_stderr(wsock2) + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - stdout_data, stderr_data = \ - await asyncio.gather(reader1.read(), reader2.read()) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock2) + _, writer3 = await asyncio.open_unix_connection(sock=sock3) + reader4, writer4 = await asyncio.open_unix_connection(sock=sock4) - writer1.close() - writer2.close() + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=reader2, stdout=writer3, + stderr=asyncssh.DEVNULL) + self.assertEqual((await reader4.read()), data) await process.wait() - self.assertEqual(stdout_data.decode(), data) - self.assertEqual(stderr_data.decode(), data) + writer2.close() + writer3.close() + writer4.close() diff --git a/tests/test_public_key.py b/tests/test_public_key.py index babcfdf..e13471d 100644 --- a/tests/test_public_key.py +++ b/tests/test_public_key.py @@ -1,4 +1,4 @@ -# Copyright (c) 2014-2022 by Ron Frederick and others. +# Copyright (c) 2014-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -36,6 +36,8 @@ import sys import unittest +from cryptography.exceptions import UnsupportedAlgorithm + import asyncssh from asyncssh.asn1 import der_encode, BitString, ObjectIdentifier @@ -53,7 +55,7 @@ from asyncssh.public_key import load_identities from .sk_stub import sk_available, stub_sk, unstub_sk -from .util import bcrypt_available, x509_available +from .util import bcrypt_available, get_test_key, x509_available from .util import make_certificate, run, TempDirTestCase @@ -97,12 +99,6 @@ _openssh_available = _openssh_version != b'' -# GCM & Chacha tests require OpenSSH 6.9 due to a bug in earlier versions: -# https://bugzilla.mindrot.org/show_bug.cgi?id=2366 -_openssh_supports_gcm_chacha = _openssh_version >= b'OpenSSH_6.9' -_openssh_supports_arcfour_blowfish_cast = (_openssh_available and - _openssh_version < b'OpenSSH_7.6') - pkcs1_ciphers = (('aes128-cbc', '-aes128', False), ('aes192-cbc', '-aes192', False), ('aes256-cbc', '-aes256', False), @@ -146,13 +142,8 @@ _openssl_available, False)) openssh_ciphers = ( - ('aes128-gcm@openssh.com', _openssh_supports_gcm_chacha), - ('aes256-gcm@openssh.com', _openssh_supports_gcm_chacha), - ('arcfour', _openssh_supports_arcfour_blowfish_cast), - ('arcfour128', _openssh_supports_arcfour_blowfish_cast), - ('arcfour256', _openssh_supports_arcfour_blowfish_cast), - ('blowfish-cbc', _openssh_supports_arcfour_blowfish_cast), - ('cast128-cbc', _openssh_supports_arcfour_blowfish_cast), + ('aes128-gcm@openssh.com', _openssh_available), + ('aes256-gcm@openssh.com', _openssh_available), ('aes128-cbc', _openssh_available), ('aes192-cbc', _openssh_available), ('aes256-cbc', _openssh_available), @@ -163,8 +154,7 @@ ) if chacha_available: # pragma: no branch - openssh_ciphers += (('chacha20-poly1305@openssh.com', - _openssh_supports_gcm_chacha),) + openssh_ciphers += (('chacha20-poly1305@openssh.com', _openssh_available),) def select_passphrase(cipher, pbe_version=0): @@ -178,7 +168,7 @@ def select_passphrase(cipher, pbe_version=0): 'rc4-40', 'rc4-128'): return 'passphrase'.encode('utf-16-be') else: - return 'passphrase'.encode('utf-8') + return b'passphrase' class _TestPublicKey(TempDirTestCase): @@ -193,6 +183,7 @@ class _TestPublicKey(TempDirTestCase): default_cert_version = '' x509_supported = False generate_args = () + single_cipher = True use_openssh = _openssh_available use_openssl = _openssl_available @@ -319,6 +310,7 @@ def check_private(self, format_name, passphrase=None): else: newkey.write_private_key('list', format_name) newkey.append_private_key('list', format_name) + write_file('list', b'Extra text at end of key list\n', 'ab') keylist = asyncssh.read_private_key_list('list') self.assertEqual(keylist[0].public_data, pubdata) @@ -405,7 +397,7 @@ def check_public(self, format_name): fp = newkey.get_fingerprint(hash_name) if self.use_openssh: # pragma: no branch - keygen_fp = run('ssh-keygen -l -E %s -f sshpub' % hash_name) + keygen_fp = run(f'ssh-keygen -l -E {hash_name} -f sshpub') self.assertEqual(fp, keygen_fp.decode('ascii').split()[1]) with self.assertRaises(ValueError): @@ -501,15 +493,15 @@ def check_certificate(self, cert_type, format_name): def import_pkcs1_private(self, fmt, cipher=None, args=None): """Check import of a PKCS#1 private key""" - format_name = 'pkcs1-%s' % fmt + format_name = f'pkcs1-{fmt}' if self.use_openssl: # pragma: no branch if cipher: - run('openssl %s %s -in priv -inform pem -out new -outform %s ' - '-passout pass:passphrase' % (self.keyclass, args, fmt)) + run(f'openssl {self.keyclass} {args} -in priv -inform pem ' + f'-out new -outform {fmt} -passout pass:passphrase') else: - run('openssl %s -in priv -inform pem -out new -outform %s' % - (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -in priv -inform pem ' + f'-out new -outform {fmt}') else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher), cipher) @@ -519,18 +511,18 @@ def import_pkcs1_private(self, fmt, cipher=None, args=None): def export_pkcs1_private(self, fmt, cipher=None, legacy_args=None): """Check export of a PKCS#1 private key""" - format_name = 'pkcs1-%s' % fmt + format_name = f'pkcs1-{fmt}' self.privkey.write_private_key('privout', format_name, select_passphrase(cipher), cipher) if self.use_openssl: # pragma: no branch if cipher: - run('openssl %s %s -in privout -inform %s -out new ' - '-outform pem -passin pass:passphrase' % - (self.keyclass, legacy_args, fmt)) + run(f'openssl {self.keyclass} {legacy_args} -in privout ' + f'-inform {fmt} -out new -outform pem ' + '-passin pass:passphrase') else: - run('openssl %s -in privout -inform %s -out new -outform pem' % - (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -in privout -inform {fmt} ' + '-out new -outform pem') else: # pragma: no cover priv = asyncssh.read_private_key('privout', select_passphrase(cipher)) @@ -541,7 +533,7 @@ def export_pkcs1_private(self, fmt, cipher=None, legacy_args=None): def import_pkcs1_public(self, fmt): """Check import of a PKCS#1 public key""" - format_name = 'pkcs1-%s' % fmt + format_name = f'pkcs1-{fmt}' if (not self.use_openssl or self.keyclass == 'dsa' or _openssl_version < b'OpenSSL 1.0.0'): # pragma: no cover @@ -551,15 +543,15 @@ def import_pkcs1_public(self, fmt): self.pubkey.write_public_key('new', format_name) else: - run('openssl %s -pubin -in pub -inform pem -RSAPublicKey_out ' - '-out new -outform %s' % (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -pubin -in pub -inform pem ' + f'-RSAPublicKey_out -out new -outform {fmt}') self.check_public(format_name) def export_pkcs1_public(self, fmt): """Check export of a PKCS#1 public key""" - format_name = 'pkcs1-%s' % fmt + format_name = f'pkcs1-{fmt}' self.privkey.write_public_key('pubout', format_name) if not self.use_openssl or self.keyclass == 'dsa': # pragma: no cover @@ -567,10 +559,10 @@ def export_pkcs1_public(self, fmt): # only test against ourselves. pub = asyncssh.read_public_key('pubout') - pub.write_public_key('new', 'pkcs1-%s' % fmt) + pub.write_public_key('new', format_name) else: - run('openssl %s -RSAPublicKey_in -in pubout -inform %s -out new ' - '-outform pem' % (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -RSAPublicKey_in -in pubout ' + f'-inform {fmt} -out new -outform pem') self.check_public(format_name) @@ -578,15 +570,15 @@ def import_pkcs8_private(self, fmt, openssl_ok=True, cipher=None, hash_alg=None, pbe_version=None, args=None): """Check import of a PKCS#8 private key""" - format_name = 'pkcs8-%s' % fmt + format_name = f'pkcs8-{fmt}' if self.use_openssl and openssl_ok: # pragma: no branch if cipher: - run('openssl pkcs8 -topk8 %s -in priv -inform pem -out new ' - '-outform %s -passout pass:passphrase' % (args, fmt)) + run(f'openssl pkcs8 -topk8 {args} -in priv -inform pem ' + f'-out new -outform {fmt} -passout pass:passphrase') else: run('openssl pkcs8 -topk8 -nocrypt -in priv -inform pem ' - '-out new -outform %s' % fmt) + f'-out new -outform {fmt}') else: # pragma: no cover self.privkey.write_private_key('new', format_name, select_passphrase(cipher, @@ -600,23 +592,21 @@ def export_pkcs8_private(self, fmt, openssl_ok=True, cipher=None, legacy_args=None): """Check export of a PKCS#8 private key""" - format_name = 'pkcs8-%s' % fmt + format_name = f'pkcs8-{fmt}' self.privkey.write_private_key('privout', format_name, select_passphrase(cipher, pbe_version), cipher, hash_alg, pbe_version) if self.use_openssl and openssl_ok: # pragma: no branch if cipher: - run('openssl pkcs8 %s -in privout -inform %s -out new ' - '-outform pem -passin pass:passphrase' % - (legacy_args, fmt)) + run(f'openssl pkcs8 {legacy_args} -in privout -inform {fmt} ' + '-out new -outform pem -passin pass:passphrase') else: - run('openssl pkcs8 -nocrypt -in privout -inform %s -out new ' - '-outform pem' % fmt) + run(f'openssl pkcs8 -nocrypt -in privout -inform {fmt} ' + '-out new -outform pem') else: # pragma: no cover - priv = asyncssh.read_private_key('privout', - select_passphrase(cipher, - pbe_version)) + priv = asyncssh.read_private_key( + 'privout', select_passphrase(cipher, pbe_version)) priv.write_private_key('new', format_name) self.check_private(format_name) @@ -624,15 +614,15 @@ def export_pkcs8_private(self, fmt, openssl_ok=True, cipher=None, def import_pkcs8_public(self, fmt): """Check import of a PKCS#8 public key""" - format_name = 'pkcs8-%s' % fmt + format_name = f'pkcs8-{fmt}' if self.use_openssl: if _openssl_supports_pkey: run('openssl pkey -pubin -in pub -inform pem -out new ' - '-outform %s' % fmt) + f'-outform {fmt}') else: # pragma: no cover - run('openssl %s -pubin -in pub -inform pem -out new ' - '-outform %s' % (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -pubin -in pub -inform pem ' + f'-out new -outform {fmt}') else: # pragma: no cover self.pubkey.write_public_key('new', format_name) @@ -641,16 +631,16 @@ def import_pkcs8_public(self, fmt): def export_pkcs8_public(self, fmt): """Check export of a PKCS#8 public key""" - format_name = 'pkcs8-%s' % fmt + format_name = f'pkcs8-{fmt}' self.privkey.write_public_key('pubout', format_name) if self.use_openssl: if _openssl_supports_pkey: - run('openssl pkey -pubin -in pubout -inform %s -out new ' - '-outform pem' % fmt) + run(f'openssl pkey -pubin -in pubout -inform {fmt} ' + '-out new -outform pem') else: # pragma: no cover - run('openssl %s -pubin -in pubout -inform %s -out new ' - '-outform pem' % (self.keyclass, fmt)) + run(f'openssl {self.keyclass} -pubin -in pubout ' + f'-inform {fmt} -out new -outform pem') else: # pragma: no cover pub = asyncssh.read_public_key('pubout') pub.write_public_key('new', format_name) @@ -664,8 +654,7 @@ def import_openssh_private(self, openssh_ok=True, cipher=None): shutil.copy('priv', 'new') if cipher: - run('ssh-keygen -p -a 1 -N passphrase -Z %s -o -f new' % - cipher) + run(f'ssh-keygen -p -a 1 -N passphrase -Z {cipher} -o -f new') else: run('ssh-keygen -p -N "" -o -f new') else: # pragma: no cover @@ -767,7 +756,7 @@ def import_rfc4716_certificate(self, cert_type, cert): """Check import of an RFC4716 certificate""" if self.use_openssh: # pragma: no branch - run('ssh-keygen -e -f %s -m rfc4716 > cert' % cert) + run(f'ssh-keygen -e -f {cert} -m rfc4716 > cert') else: # pragma: no cover if cert_type == CERT_TYPE_USER: cert = self.usercert @@ -860,7 +849,7 @@ def check_encode_errors(self): for fmt in ('pkcs1-der', 'pkcs1-pem', 'pkcs8-der', 'pkcs8-pem', 'openssh', 'rfc4716', 'xxx'): - with self.subTest('Encode private from public (%s)' % fmt): + with self.subTest(f'Encode private from public ({fmt})'): with self.assertRaises(asyncssh.KeyExportError): self.pubkey.export_private_key(fmt) @@ -1221,23 +1210,23 @@ def check_decode_errors(self): ] for fmt, data in private_errors: - with self.subTest('Decode private (%s)' % fmt): + with self.subTest(f'Decode private ({fmt})'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_private_key(data) for fmt, data in decrypt_errors: - with self.subTest('Decrypt private (%s)' % fmt): + with self.subTest('fDecrypt private ({fmt})'): with self.assertRaises((asyncssh.KeyEncryptionError, asyncssh.KeyImportError)): asyncssh.import_private_key(data, 'x') for fmt, data in public_errors: - with self.subTest('Decode public (%s)' % fmt): + with self.subTest(f'Decode public ({fmt})'): with self.assertRaises(asyncssh.KeyImportError): asyncssh.import_public_key(data) for fmt, key in keypair_errors: - with self.subTest('Load keypair (%s)' % fmt): + with self.subTest(f'Load keypair ({fmt})'): with self.assertRaises(ValueError): asyncssh.load_keypairs([key]) @@ -1276,26 +1265,40 @@ def check_sign_and_verify(self): for sig_alg in keypair.sig_algorithms: with self.subTest('Good signature', sig_alg=sig_alg): - keypair.set_sig_algorithm(sig_alg) - sig = keypair.sign(data) + try: + keypair.set_sig_algorithm(sig_alg) + sig = keypair.sign(data) + + with self.subTest('Good signature'): + self.assertTrue(self.pubkey.verify(data, sig)) - with self.subTest('Good signature'): - self.assertTrue(self.pubkey.verify(data, sig)) + badsig = bytearray(sig) + badsig[-1] ^= 0xff + badsig = bytes(badsig) - badsig = bytearray(sig) - badsig[-1] ^= 0xff - badsig = bytes(badsig) + with self.subTest('Bad signature'): + self.assertFalse(self.pubkey.verify(data, + badsig)) - with self.subTest('Bad signature'): - self.assertFalse(self.pubkey.verify(data, badsig)) + if sig_alg.startswith(b'webauthn-'): + idx = sig.rfind(b'ssh:') + badpfx = bytearray(sig) + badpfx[idx] = ord('x') + badpfx = bytes(badpfx) + + with self.subTest('Bad prefix'): + self.assertFalse(self.pubkey.verify(data, + badpfx)) + except UnsupportedAlgorithm: # pragma: no cover + pass with self.subTest('Missing signature'): self.assertFalse(self.pubkey.verify( - data, String(self.pubkey.algorithm))) + data, String(self.pubkey.sig_algorithms[0]))) with self.subTest('Empty signature'): self.assertFalse(self.pubkey.verify( - data, String(self.pubkey.algorithm) + String(b''))) + data, String(self.pubkey.sig_algorithms[0]) + String(b''))) with self.subTest('Sign with bad algorithm'): with self.assertRaises(ValueError): @@ -1307,7 +1310,7 @@ def check_sign_and_verify(self): with self.subTest('Sign with public key'): with self.assertRaises(ValueError): - self.pubkey.sign(data, self.pubkey.algorithm) + self.pubkey.sign(data, self.pubkey.sig_algorithms[0]) def check_set_certificate(self): """Check setting certificate on existing keypair""" @@ -1320,7 +1323,7 @@ def check_set_certificate(self): keypair = asyncssh.load_keypairs((keypair, self.usercert))[0] self.assertEqual(keypair.public_data, self.usercert.public_data) - key2 = asyncssh.generate_private_key('ssh-rsa') + key2 = get_test_key('ssh-rsa', 1) with self.assertRaises(ValueError): asyncssh.load_keypairs((key2, self.usercert)) @@ -1534,7 +1537,7 @@ def check_comment(self): keypair = asyncssh.load_keypairs(('key', None))[0] self.assertEqual(keypair.get_comment(), 'pub_comment') - key2 = asyncssh.generate_private_key('ssh-rsa') + key2 = get_test_key('ssh-rsa', 1) with self.assertRaises(ValueError): asyncssh.load_keypairs((key2, 'pub')) @@ -1560,10 +1563,10 @@ def check_pkcs1_private(self): for cipher, args, legacy in pkcs1_ciphers: legacy_args = _openssl_legacy if legacy else '' - with self.subTest('Import PKCS#1 PEM private (%s)' % cipher): + with self.subTest(f'Import PKCS#1 PEM private ({cipher})'): self.import_pkcs1_private('pem', cipher, legacy_args + args) - with self.subTest('Export PKCS#1 PEM private (%s)' % cipher): + with self.subTest(f'Export PKCS#1 PEM private ({cipher})'): self.export_pkcs1_private('pem', cipher, legacy_args) def check_pkcs1_public(self): @@ -1600,28 +1603,31 @@ def check_pkcs8_private(self): openssl_ok, legacy in pkcs8_ciphers: legacy_args = _openssl_legacy if legacy else '' - with self.subTest('Import PKCS#8 PEM private (%s-%s-v%s)' % - (cipher, hash_alg, pbe_version)): + with self.subTest(f'Import PKCS#8 PEM private ({cipher}-' + f'{hash_alg}-v{pbe_version})'): self.import_pkcs8_private('pem', openssl_ok, cipher, hash_alg, pbe_version, legacy_args + args) - with self.subTest('Export PKCS#8 PEM private (%s-%s-v%s)' % - (cipher, hash_alg, pbe_version)): + with self.subTest(f'Export PKCS#8 PEM private ({cipher}-' + f'{hash_alg}-v{pbe_version})'): self.export_pkcs8_private('pem', openssl_ok, cipher, hash_alg, pbe_version, legacy_args) - with self.subTest('Import PKCS#8 DER private (%s-%s-v%s)' % - (cipher, hash_alg, pbe_version)): + with self.subTest(f'Import PKCS#8 DER private ({cipher}-' + f'{hash_alg}-v{pbe_version})'): self.import_pkcs8_private('der', openssl_ok, cipher, hash_alg, pbe_version, legacy_args + args) - with self.subTest('Export PKCS#8 DER private (%s-%s-v%s)' % - (cipher, hash_alg, pbe_version)): + with self.subTest(f'Export PKCS#8 DER private ({cipher}-' + f'{hash_alg}-v{pbe_version})'): self.export_pkcs8_private('der', openssl_ok, cipher, hash_alg, pbe_version, legacy_args) + if self.single_cipher: + break + def check_pkcs8_public(self): """Check PKCS#8 public key format""" @@ -1648,12 +1654,15 @@ def check_openssh_private(self): if bcrypt_available: # pragma: no branch for cipher, openssh_ok in openssh_ciphers: - with self.subTest('Import OpenSSH private (%s)' % cipher): + with self.subTest(f'Import OpenSSH private ({cipher})'): self.import_openssh_private(openssh_ok, cipher) - with self.subTest('Export OpenSSH private (%s)' % cipher): + with self.subTest(f'Export OpenSSH private ({cipher})'): self.export_openssh_private(openssh_ok, cipher) + if self.single_cipher: + break + def check_openssh_public(self): """Check OpenSSH public key format""" @@ -2002,7 +2011,7 @@ def test_keys(self): for alg_name, kwargs in self.generate_args: with self.subTest(alg_name=alg_name, **kwargs): - self.privkey = asyncssh.generate_private_key( + self.privkey = get_test_key( alg_name, comment='comment', **kwargs) self.privkey.write_private_key('priv', self.base_format) @@ -2012,7 +2021,7 @@ def test_keys(self): self.pubkey.write_public_key('pub', self.base_format) self.pubkey.write_public_key('sshpub', 'openssh') - self.privca = asyncssh.generate_private_key(alg_name, **kwargs) + self.privca = get_test_key(alg_name, 1, **kwargs) self.privca.write_private_key('privca', self.base_format) self.pubca = self.privca.convert_to_public() @@ -2022,8 +2031,10 @@ def test_keys(self): self.pubkey, 'name', comment='user_comment') self.usercert.write_certificate('usercert') + hostcert_sig_alg = self.privca.sig_algorithms[0].decode() self.hostcert = self.privca.generate_host_certificate( - self.pubkey, 'name', comment='host_comment') + self.pubkey, 'name', sig_alg=hostcert_sig_alg, + comment='host_comment') self.hostcert.write_certificate('hostcert') for f in ('priv', 'privca'): @@ -2109,6 +2120,7 @@ class TestDSA(_TestPublicKey): default_cert_version = 'ssh-dss-cert-v01@openssh.com' x509_supported = x509_available generate_args = (('ssh-dss', {}),) + use_openssh = False class TestRSA(_TestPublicKey): @@ -2156,6 +2168,7 @@ class TestEd25519(_TestPublicKey): x509_supported = x509_available default_cert_version = 'ssh-ed25519-cert-v01@openssh.com' generate_args = (('ssh-ed25519', {}),) + single_cipher = False use_openssh = False use_openssl = _openssl_supports_pkey @@ -2234,9 +2247,9 @@ def test_public_key(self): self.assertEqual(bool(get_x509_certificate_algs()), x509_available) def test_public_key_algorithm_mismatch(self): - """Test algorihm mismatch in SSH public key""" + """Test algorithm mismatch in SSH public key""" - privkey = asyncssh.generate_private_key('ssh-rsa') + privkey = get_test_key('ssh-rsa') keydata = privkey.export_public_key('openssh') keydata = b'ssh-dss ' + keydata.split(None, 1)[1] @@ -2255,14 +2268,14 @@ def test_pad_error(self): pkcs1_decrypt(b'', b'AES-128-CBC', os.urandom(16), 'x') def test_ec_explicit(self): - """Test EC certificate with explcit parameters""" + """Test EC certificate with explicit parameters""" if _openssl_available: # pragma: no branch for curve in ('secp256r1', 'secp384r1', 'secp521r1'): with self.subTest('Import EC key with explicit parameters', curve=curve): - run('openssl ecparam -out priv -noout -genkey -name %s ' - '-param_enc explicit' % curve) + run('openssl ecparam -out priv -noout -genkey ' + f'-name {curve} -param_enc explicit') asyncssh.read_private_key('priv') @unittest.skipIf(not _openssl_available, "openssl isn't available") @@ -2290,9 +2303,9 @@ def test_generate_errors(self): with self.assertRaises(asyncssh.KeyGenerationError): asyncssh.generate_private_key(alg_name, **kwargs) - privkey = asyncssh.generate_private_key('ssh-rsa') + privkey = get_test_key('ssh-rsa') pubkey = privkey.convert_to_public() - privca = asyncssh.generate_private_key('ssh-rsa') + privca = get_test_key('ssh-rsa', 1) with self.assertRaises(asyncssh.KeyGenerationError): privca.generate_user_certificate(pubkey, 'name', version=0) @@ -2328,7 +2341,7 @@ def test_generate_errors(self): def test_rsa_encrypt_error(self): """Test RSA encryption error""" - privkey = asyncssh.generate_private_key('ssh-rsa', 2048) + privkey = get_test_key('ssh-rsa', 2048) pubkey = privkey.convert_to_public() self.assertIsNone(pubkey.encrypt(os.urandom(256), pubkey.algorithm)) @@ -2336,7 +2349,7 @@ def test_rsa_encrypt_error(self): def test_rsa_decrypt_error(self): """Test RSA decryption error""" - privkey = asyncssh.generate_private_key('ssh-rsa', 2048) + privkey = get_test_key('ssh-rsa', 2048) self.assertIsNone(privkey.decrypt(b'', privkey.algorithm)) @@ -2344,10 +2357,10 @@ def test_rsa_decrypt_error(self): def test_x509_certificate_hashes(self): """Test X.509 certificate hash algorithms""" - privkey = asyncssh.generate_private_key('ssh-rsa') + privkey = get_test_key('ssh-rsa') pubkey = privkey.convert_to_public() - for hash_alg in ('sha1', 'sha256', 'sha512'): + for hash_alg in ('sha256', 'sha512'): cert = privkey.generate_x509_user_certificate( pubkey, 'OU=user', hash_alg=hash_alg) diff --git a/tests/test_saslprep.py b/tests/test_saslprep.py index 5e39ff5..ca29f3a 100644 --- a/tests/test_saslprep.py +++ b/tests/test_saslprep.py @@ -37,7 +37,7 @@ def test_unassigned(self): """Test passing strings with unassigned code points""" for s in ('\u0221', '\u038b', '\u0510', '\u070e', '\u0900', '\u0a00'): - with self.assertRaises(SASLPrepError, msg='U+%08x' % ord(s)): + with self.assertRaises(SASLPrepError, msg=f'U+{ord(s):08x}'): saslprep('abc' + s + 'def') def test_map_to_nothing(self): @@ -45,19 +45,19 @@ def test_map_to_nothing(self): for s in ('\u00ad', '\u034f', '\u1806', '\u200c', '\u2060', '\ufe00'): self.assertEqual(saslprep('abc' + s + 'def'), 'abcdef', - msg='U+%08x' % ord(s)) + msg=f'U+{ord(s):08x}') def test_map_to_whitespace(self): """Test passing strings with characters that map to whitespace""" for s in ('\u00a0', '\u1680', '\u2000', '\u202f', '\u205f', '\u3000'): self.assertEqual(saslprep('abc' + s + 'def'), 'abc def', - msg='U+%08x' % ord(s)) + msg=f'U+{ord(s):08x}') def test_normalization(self): """Test Unicode normalization form KC conversions""" for (s, n) in (('\u00aa', 'a'), ('\u2168', 'IX')): self.assertEqual(saslprep('abc' + s + 'def'), 'abc' + n + 'def', - msg='U+%08x' % ord(s)) + msg=f'U+{ord(s):08x}') def test_prohibited(self): """Test passing strings with prohibited characters""" @@ -65,7 +65,7 @@ def test_prohibited(self): '\u2028', '\u202a', '\u206a', '\u2ff0', '\u2ffb', '\ud800', '\udfff', '\ue000', '\ufdd0', '\ufef9', '\ufffc', '\uffff', '\U0001d173', '\U000E0001', '\U00100000', '\U0010fffd'): - with self.assertRaises(SASLPrepError, msg='U+%08x' % ord(s)): + with self.assertRaises(SASLPrepError, msg=f'U+{ord(s):08x}'): saslprep('abc' + s + 'def') def test_bidi(self): diff --git a/tests/test_sftp.py b/tests/test_sftp.py index 0fe7e65..47fdde0 100644 --- a/tests/test_sftp.py +++ b/tests/test_sftp.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2022 by Ron Frederick and others. +# Copyright (c) 2015-2025 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,6 +20,7 @@ """Unit tests for AsyncSSH SFTP client and server""" +import asyncio import errno import functools import os @@ -66,8 +67,12 @@ from asyncssh import FILEXFER_ATTR_BITS_READONLY, FILEXFER_ATTR_KNOWN_TEXT from asyncssh import FX_OK, scp +from asyncssh.misc import make_sparse_file + from asyncssh.packet import SSHPacket, String, UInt32 -from asyncssh.sftp import LocalFile, SFTPHandler, SFTPServerHandler + +from asyncssh.sftp import SAFE_SFTP_READ_LEN, SAFE_SFTP_WRITE_LEN +from asyncssh.sftp import LocalFile, SFTPHandler, SFTPLimits, SFTPServerHandler from .server import ServerTestCase from .util import asynctest @@ -252,7 +257,7 @@ async def _process_packet(self, pkttype, pktid, packet): class _CheckPropSFTPServer(SFTPServer): """Return an FTP server which checks channel properties""" - def listdir(self, path): + def listdir(self, _path): """List the contents of a directory""" if self.channel.get_connection() == self.connection: # pragma: no branch @@ -291,17 +296,17 @@ class _IOErrorSFTPServer(SFTPServer): """Return an I/O error during file writing""" async def read(self, file_obj, offset, size): - """Return an error for reads past 64 KB in a file""" + """Return an error for reads past 4 MB in a file""" - if offset >= 65536: + if offset >= 4*1024*1024: raise SFTPFailure('I/O error') else: return super().read(file_obj, offset, size) async def write(self, file_obj, offset, data): - """Return an error for writes past 64 KB in a file""" + """Return an error for writes past 4 MB in a file""" - if offset >= 65536: + if offset >= 4*1024*1024: raise SFTPFailure('I/O error') else: super().write(file_obj, offset, data) @@ -349,7 +354,7 @@ class _FileTypeSFTPServer(SFTPServer): (FILEXFER_TYPE_BLOCK_DEVICE, stat.S_IFBLK), (FILEXFER_TYPE_FIFO, stat.S_IFIFO)) - def listdir(self, path): + def listdir(self, _path): """List the contents of a directory""" return [SFTPName(str(filetype).encode('ascii'), @@ -360,9 +365,11 @@ def listdir(self, path): class _LongnameSFTPServer(SFTPServer): """Return a fixed set of files in response to a listdir request""" - def listdir(self, path): + def listdir(self, _path): """List the contents of a directory""" + # pylint: disable=no-self-use + return list((b'.', b'..', SFTPName(b'.file'), @@ -416,7 +423,7 @@ class _ChownSFTPServer(SFTPServer): _ownership = {} def setstat(self, path, attrs): - """Set attributes of a file or directory""" + """Get attributes of a file or directory, following symlinks""" self._ownership[self.map_path(path)] = \ (attrs.uid, attrs.gid, attrs.owner, attrs.group) @@ -458,11 +465,22 @@ async def stat(self, path): else: raise SFTPError(99, exc.strerror) from None + async def lstat(self, path): + """Get attributes of a local file, directory, or symlink""" + + return SFTPAttrs.from_local(super().lstat(path)) + async def fstat(self, file_obj): """Get attributes of an open file""" return SFTPAttrs.from_local(super().fstat(file_obj)) + async def scandir(self, path): + """Return names and attributes of the files in a local directory""" + + async for name in super().scandir(path): + yield name + class _AsyncSFTPServer(SFTPServer): """Implement all SFTP callbacks as async methods""" @@ -505,19 +523,24 @@ async def fstat(self, file_obj): return super().fstat(file_obj) async def setstat(self, path, attrs): - """Set attributes of a file or directory""" + """Set attributes of a file or directory, following symlinks""" super().setstat(path, attrs) + async def lsetstat(self, path, attrs): + """Set attributes of a file, directory, or symlink""" + + super().lsetstat(path, attrs) + async def fsetstat(self, file_obj, attrs): """Set attributes of an open file""" super().fsetstat(file_obj, attrs) - async def listdir(self, path): - """List the contents of a directory""" + def scandir(self, path): + """Scan the contents of a directory""" - return super().listdir(path) + return super().scandir(path) async def remove(self, path): """Remove a file or symbolic link""" @@ -594,6 +617,11 @@ async def fsync(self, file_obj): super().fsync(file_obj) + async def exit(self): + """Shut down this SFTP server""" + + super().exit() + class _CheckSFTP(ServerTestCase): """Utility functions for AsyncSSH SFTP unit tests""" @@ -611,7 +639,7 @@ def setUpClass(cls): except OSError: # pragma: no cover cls._symlink_supported = False - def _create_file(self, name, data=(), mode=None, utime=None): + def _create_file(self, name, data=(), offsets=(0,), mode=None, utime=None): """Create a test file""" if data == (): @@ -620,7 +648,11 @@ def _create_file(self, name, data=(), mode=None, utime=None): binary = 'b' if isinstance(data, bytes) else '' with open(name, 'w' + binary) as f: - f.write(data) + make_sparse_file(f) + + for offset in offsets: + f.seek(offset) + f.write(data) if mode is not None: os.chmod(name, mode) @@ -638,6 +670,7 @@ def _check_attr(self, name1, name2, follow_symlinks, check_atime): self.assertEqual(stat.S_IMODE(attrs1.st_mode), stat.S_IMODE(attrs2.st_mode)) + self.assertEqual(attrs1.st_size, attrs2.st_size) self.assertEqual(int(attrs1.st_mtime), int(attrs2.st_mtime)) if check_atime: @@ -654,6 +687,22 @@ def _check_file(self, name1, name2, preserve=False, follow_symlinks=False, with open(name2, 'rb') as file2: self.assertEqual(file1.read(), file2.read()) + async def _check_sparse_file(self, name1, name2): + """Check if two sparse files are equal""" + + size1 = os.stat(name1).st_size + size2 = os.stat(name2).st_size + self.assertEqual(size1, size2) + + with open(name1, 'rb') as file1: + with open(name2, 'rb') as file2: + ranges1 = [range async for range in + LocalFile(file1).request_ranges(0, size1)] + ranges2 = [range async for range in + LocalFile(file2).request_ranges(0, size2)] + + self.assertEqual(ranges1, ranges2) + def _check_stat(self, sftp_stat, local_stat): """Check if file attributes are equal""" @@ -683,7 +732,12 @@ def _check_stat_v4(self, sftp_stat, local_stat): def _check_link(self, link, target): """Check if a symlink points to the right target""" - self.assertEqual(os.readlink(link), target) + link = os.readlink(link) + + if link.startswith('\\\\?\\'): # pragma: no cover + link = link[4:] + + self.assertEqual(Path(link).resolve(), Path(target).resolve()) class _TestSFTP(_CheckSFTP): @@ -727,6 +781,104 @@ async def test_copy(self, sftp): finally: remove('src dst') + @sftp_test + async def test_sparse_copy(self, sftp): + """Test putting a sparse file over SFTP""" + + for method in ('get', 'put', 'copy'): + with self.subTest(method=method): + try: + self._create_file( + 'src', offsets=(i*1024*1024 for i in + range(24, 3840, 24))) + await getattr(sftp, method)('src', 'dst') + await self._check_sparse_file('src', 'dst') + finally: + remove('src dst') + + @sftp_test + async def test_empty_request_range(self, sftp): + """Test getting ranges from an empty file""" + + try: + self._create_file('file', data=b'') + + async with sftp.open('file', 'rb') as f: + result = [data_range async for data_range in + f.request_ranges(0, 0)] + self.assertEqual(result, []) + finally: + remove('file') + + @sftp_test + async def test_nonsparse_put(self, sftp): + """Test putting a sparse file over SFTP with sparse mode disabled""" + + try: + self._create_file( + 'src', offsets=(i*1024*1024 for i in range(24, 72, 24))) + await sftp.put('src', 'dst', sparse=False) + self._check_file('src', 'dst') + finally: + remove('src dst') + + @sftp_test + async def test_copy_max_requests(self, sftp): + """Test copying a file over SFTP with max requests set""" + + for method in ('get', 'put', 'copy'): + for src in ('src', b'src', Path('src')): + with self.subTest(method=method, src=type(src)): + try: + self._create_file('src', 16*1024*1024*'\0') + await getattr(sftp, method)(src, 'dst', + max_requests=4) + self._check_file('src', 'dst') + finally: + remove('src dst') + + def test_copy_non_remote(self): + """Test copying without using remote_copy function""" + + @sftp_test + async def _test_copy_non_remote(self, sftp): + """Test copying without using remote_copy function""" + + for method in ('copy', 'mcopy'): + with self.subTest(method=method): + try: + self._create_file('src') + await getattr(sftp, method)('src', 'dst') + self._check_file('src', 'dst') + finally: + remove('src dst') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_non_remote(self) + + def test_copy_remote_only(self): + """Test copying while allowing only remote copy""" + + @sftp_test + async def _test_copy_remote_only(self, sftp): + """Test copying with only remote copy allowed""" + + for method in ('copy', 'mcopy'): + with self.subTest(method=method): + try: + self._create_file('src') + + with self.assertRaises(SFTPOpUnsupported): + await getattr(sftp, method)('src', 'dst', + remote_only=True) + finally: + remove('src') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_remote_only(self) + @sftp_test async def test_copy_progress(self, sftp): """Test copying a file over SFTP with progress reporting""" @@ -748,7 +900,9 @@ def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes): progress_handler=_report_progress) self._check_file('src', 'dst') - self.assertEqual(len(reports), (size // 8192) + 1) + if method != 'copy': + self.assertEqual(len(reports), (size // 8192) + 1) + self.assertEqual(reports[-1], size) finally: remove('src dst') @@ -766,6 +920,43 @@ async def test_copy_preserve(self, sftp): finally: remove('src dst') + @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') + @sftp_test + async def test_copy_preserve_link(self, sftp): + """Test copying a symlink with preserved attributes over SFTP""" + + for method in ('get', 'put', 'copy'): + with self.subTest(method=method): + try: + os.symlink('file', 'link1') + os.utime('link1', times=(1, 2), follow_symlinks=False) + await getattr(sftp, method)( + 'link1', 'link2', preserve=True, follow_symlinks=False) + self.assertEqual(os.lstat('link2').st_mtime, 2) + finally: + remove('link1 link2') + + @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') + def test_copy_preserve_link_unsupported(self): + """Test preserving symlink attributes over SFTP without lsetstat""" + + @sftp_test + async def _lsetstat_unsupported(self, sftp): + """Try copying link attributes without lsetstat""" + + try: + os.symlink('file', 'link1') + os.utime('link1', times=(1, 2), follow_symlinks=False) + await sftp.put('link1', 'link2', preserve=True, + follow_symlinks=False) + self.assertNotEqual(int(os.lstat('link2').st_mtime), 2) + finally: + remove('link1 link2') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _lsetstat_unsupported(self) + @sftp_test async def test_copy_recurse(self, sftp): """Test recursively copying a directory over SFTP""" @@ -830,13 +1021,32 @@ async def test_copy_follow_symlinks(self, sftp): finally: remove('src dst link') + @sftp_test + async def test_copy_recurse_follow_symlinks(self, sftp): + """Test recursively copying over SFTP while following symlinks""" + + if not self._symlink_supported: # pragma: no cover + raise unittest.SkipTest('symlink not available') + + for method in ('get', 'put', 'copy'): + with self.subTest(method=method): + try: + os.mkdir('src') + self._create_file('src/file1') + os.symlink('file1', 'src/file2') + await getattr(sftp, method)('src', 'dst', recurse=True, + follow_symlinks=True) + self._check_file('src/file1', 'dst/file2') + finally: + remove('src dst') + @sftp_test async def test_copy_invalid_name(self, sftp): """Test copying a file with an invalid name over SFTP""" for method in ('get', 'put', 'copy', 'mget', 'mput', 'mcopy'): with self.subTest(method=method): - with self.assertRaises((FileNotFoundError, SFTPNoSuchFile, + with self.assertRaises((OSError, SFTPNoSuchFile, SFTPFailure, UnicodeDecodeError)): await getattr(sftp, method)(b'\xff') @@ -1053,6 +1263,56 @@ def err_handler(exc): finally: remove('src1 src2 dst') + def test_remote_copy_unsupported(self): + """Test remote copy on a server which doesn't support it""" + + @sftp_test + async def _test_remote_copy_unsupported(self, sftp): + """Test remote copy not being supported""" + + try: + self._create_file('src') + + with self.assertRaises(SFTPOpUnsupported): + await sftp.remote_copy('src', 'dst') + finally: + remove('src') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_remote_copy_unsupported(self) + + @sftp_test + async def test_remote_copy_arguments(self, sftp): + """Test remote copy arguments""" + + try: + self._create_file('src', os.urandom(2*1024*1024)) + + async with sftp.open('src', 'rb') as src: + async with sftp.open('dst', 'wb') as dst: + await sftp.remote_copy(src, dst, 0, 1024*1024, 0) + await sftp.remote_copy(src, dst, 1024*1024, 0, 1024*1024) + + self._check_file('src', 'dst') + finally: + remove('src dst') + + @sftp_test + async def test_remote_copy_closed_file(self, sftp): + """Test remote copy of a closed file""" + + try: + self._create_file('file') + + async with sftp.open('file', 'rb') as f: + await f.close() + + with self.assertRaises(ValueError): + await sftp.remote_copy(f, f) + finally: + remove('file') + @sftp_test async def test_glob(self, sftp): """Test a glob pattern match over SFTP""" @@ -1066,7 +1326,11 @@ async def test_glob(self, sftp): (['file*/*2'], ['filedir/file2', 'filedir/filedir2']), (['file*/*[3-9]'], ['filedir/file3']), (['**/file[12]'], ['file1', 'filedir/file2']), - (['**/file*/'], ['filedir', 'filedir/filedir2']), + (['**/file*/'], ['filedir/', 'filedir/filedir2/']), + (['filedir/**'], ['filedir', 'filedir/file2', + 'filedir/file3', 'filedir/filedir2', + 'filedir/filedir2/file4', + 'filedir/filedir2/file5']), ('filedir/file2', ['filedir/file2']), ('./filedir/file2', ['./filedir/file2']), ('filedir/file*', ['filedir/file2', 'filedir/file3', @@ -1087,7 +1351,8 @@ async def test_glob(self, sftp): ('filedir/filedir*/file*', ['filedir/filedir2/file4', 'filedir/filedir2/file5']), ('./**/filedir2/file4', ['./filedir/filedir2/file4']), - ('**/filedir2/file4', ['filedir/filedir2/file4'])) + ('**/filedir2/file4', ['filedir/filedir2/file4']), + (['file1', '**/file1'], ['file1'])) try: os.mkdir('filedir') @@ -1100,7 +1365,7 @@ async def test_glob(self, sftp): for pattern, matches in glob_tests: with self.subTest(pattern=pattern): - self.assertEqual(sorted((await sftp.glob(pattern))), + self.assertEqual(sorted(await sftp.glob(pattern)), matches) self.assertEqual((await sftp.glob([b'fil*1', 'fil*dir'])), @@ -1109,11 +1374,28 @@ async def test_glob(self, sftp): remove('file1 filedir') @sftp_test - async def test_glob_error(self, sftp): - """Test a glob pattern match error over SFTP""" + async def test_glob_errors(self, sftp): + """Test glob pattern match errors over SFTP""" - with self.assertRaises(SFTPNoSuchFile): - await sftp.glob('file*') + _glob_errors = ( + 'file*', + 'dir/file1/*', + 'dir*/file1/*', + 'dir/dir1/*') + + try: + os.mkdir('dir') + self._create_file('dir/file1') + os.mkdir('dir/dir1') + os.chmod('dir/dir1', 0) + + for pattern in _glob_errors: + with self.subTest(pattern=pattern): + with self.assertRaises(SFTPNoSuchFile): + await sftp.glob(pattern) + finally: + os.chmod('dir/dir1', 0o700) + remove('dir') @sftp_test_v4 async def test_glob_error_v4(self, sftp): @@ -1165,29 +1447,29 @@ async def test_stat(self, sftp): with self.assertRaises(SFTPNoSuchFile): await sftp.stat('badlink') - self.assertTrue((await sftp.isdir('dir'))) - self.assertFalse((await sftp.isdir('file'))) + self.assertTrue(await sftp.isdir('dir')) + self.assertFalse(await sftp.isdir('file')) if self._symlink_supported: # pragma: no branch - self.assertFalse((await sftp.isdir('badlink'))) - self.assertTrue((await sftp.isdir('dirlink'))) - self.assertFalse((await sftp.isdir('filelink'))) + self.assertFalse(await sftp.isdir('badlink')) + self.assertTrue(await sftp.isdir('dirlink')) + self.assertFalse(await sftp.isdir('filelink')) - self.assertFalse((await sftp.isfile('dir'))) - self.assertTrue((await sftp.isfile('file'))) + self.assertFalse(await sftp.isfile('dir')) + self.assertTrue(await sftp.isfile('file')) if self._symlink_supported: # pragma: no branch - self.assertFalse((await sftp.isfile('badlink'))) - self.assertFalse((await sftp.isfile('dirlink'))) - self.assertTrue((await sftp.isfile('filelink'))) + self.assertFalse(await sftp.isfile('badlink')) + self.assertFalse(await sftp.isfile('dirlink')) + self.assertTrue(await sftp.isfile('filelink')) - self.assertFalse((await sftp.islink('dir'))) - self.assertFalse((await sftp.islink('file'))) + self.assertFalse(await sftp.islink('dir')) + self.assertFalse(await sftp.islink('file')) if self._symlink_supported: # pragma: no branch - self.assertTrue((await sftp.islink('badlink'))) - self.assertTrue((await sftp.islink('dirlink'))) - self.assertTrue((await sftp.islink('filelink'))) + self.assertTrue(await sftp.islink('badlink')) + self.assertTrue(await sftp.islink('dirlink')) + self.assertTrue(await sftp.islink('filelink')) finally: remove('dir file badlink dirlink filelink') @@ -1217,6 +1499,33 @@ async def test_lstat_v4(self, sftp): finally: remove('link') + @sftp_test_v6 + async def test_lstat_v6(self, sftp): + """Test getting attributes on a link with SFTPv6""" + + if not self._symlink_supported: # pragma: no cover + raise unittest.SkipTest('symlink not available') + + try: + os.symlink('file', 'link') + self._check_stat_v4((await sftp.lstat('link')), os.lstat('link')) + finally: + remove('link') + + @sftp_test + async def test_lstat_via_stat(self, sftp): + """Test getting attributes on a link by disabling follow_symlinks""" + + if not self._symlink_supported: # pragma: no cover + raise unittest.SkipTest('symlink not available') + + try: + os.symlink('file', 'link') + self._check_stat((await sftp.stat('link', follow_symlinks=False)), + os.lstat('link')) + finally: + remove('link') + @sftp_test async def test_setstat(self, sftp): """Test setting attributes on a file""" @@ -1271,6 +1580,63 @@ async def test_setstat_invalid_owner_group_v6(self, sftp): finally: remove('file') + @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') + @sftp_test + async def test_lsetstat(self, sftp): + """Test setting attributes on a link""" + + try: + os.symlink('file', 'link') + + await sftp.setstat('link', SFTPAttrs(atime=1, mtime=2), + follow_symlinks=False) + + stat_result = os.lstat('link') + self.assertEqual(stat_result.st_atime, 1) + self.assertEqual(stat_result.st_mtime, 2) + finally: + remove('link') + + @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') + @sftp_test_v4 + async def test_lsetstat_v4(self, sftp): + """Test setting attributes on a link""" + + try: + os.symlink('file', 'link') + + await sftp.setstat('link', SFTPAttrs(atime=1), + follow_symlinks=False) + + self.assertEqual(os.lstat('link').st_atime, 1) + + await sftp.setstat('link', SFTPAttrs(mtime=2), + follow_symlinks=False) + + self.assertEqual(os.lstat('link').st_mtime, 2) + finally: + remove('link') + + @unittest.skipIf(sys.platform == 'win32', 'skip lsetstat tests on Windows') + @sftp_test_v6 + async def test_lsetstat_v6(self, sftp): + """Test setting attributes on a link""" + + try: + os.symlink('file', 'link') + + await sftp.setstat('link', SFTPAttrs(atime=1), + follow_symlinks=False) + + self.assertEqual(os.lstat('link').st_atime, 1) + + await sftp.setstat('link', SFTPAttrs(mtime=2), + follow_symlinks=False) + + self.assertEqual(os.lstat('link').st_mtime, 2) + finally: + remove('link') + @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') @sftp_test async def test_statvfs(self, sftp): @@ -1422,8 +1788,8 @@ async def test_utime_v4(self, sftp): self.assertEqual(stat_result.st_mtime_ns, 2250000000) self.assertEqual((await sftp.getatime('file')), 1.0) self.assertEqual((await sftp.getatime_ns('file')), 1000000000) - self.assertIsNotNone((await sftp.getcrtime('file'))) - self.assertIsNotNone((await sftp.getcrtime_ns('file'))) + self.assertIsNotNone(await sftp.getcrtime('file')) + self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 2.25) self.assertEqual((await sftp.getmtime_ns('file')), 2250000000) @@ -1436,8 +1802,8 @@ async def test_utime_v4(self, sftp): self.assertEqual(stat_result.st_mtime_ns, 4750000000) self.assertEqual((await sftp.getatime('file')), 3.5) self.assertEqual((await sftp.getatime_ns('file')), 3500000000) - self.assertIsNotNone((await sftp.getcrtime('file'))) - self.assertIsNotNone((await sftp.getcrtime_ns('file'))) + self.assertIsNotNone(await sftp.getcrtime('file')) + self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 4.75) self.assertEqual((await sftp.getmtime_ns('file')), 4750000000) finally: @@ -1450,8 +1816,8 @@ async def test_exists(self, sftp): try: self._create_file('file1') - self.assertTrue((await sftp.exists('file1'))) - self.assertFalse((await sftp.exists('file2'))) + self.assertTrue(await sftp.exists('file1')) + self.assertFalse(await sftp.exists('file2')) finally: remove('file1') @@ -1465,8 +1831,8 @@ async def test_lexists(self, sftp): try: os.symlink('file', 'link1') - self.assertTrue((await sftp.lexists('link1'))) - self.assertFalse((await sftp.lexists('link2'))) + self.assertTrue(await sftp.lexists('link1')) + self.assertFalse(await sftp.lexists('link2')) finally: remove('link1') @@ -1586,7 +1952,7 @@ async def test_listdir(self, sftp): os.mkdir('dir') self._create_file('dir/file1') self._create_file('dir/file2') - self.assertEqual(sorted((await sftp.listdir('dir'))), + self.assertEqual(sorted(await sftp.listdir('dir')), ['.', '..', 'file1', 'file2']) finally: remove('dir') @@ -1599,7 +1965,7 @@ async def test_listdir_v4(self, sftp): os.mkdir('dir') self._create_file('dir/file1') self._create_file('dir/file2') - self.assertEqual(sorted((await sftp.listdir('dir'))), + self.assertEqual(sorted(await sftp.listdir('dir')), ['.', '..', 'file1', 'file2']) finally: remove('dir') @@ -1965,6 +2331,10 @@ async def test_open_read_bytes(self, sftp): f = await sftp.open('file', 'rb') self.assertEqual((await f.read()), b'xxx') + + await f.seek(0) + self.assertEqual([result async for result in + await f.read_parallel()], [(0, b'xxx')]) finally: if f: # pragma: no branch await f.close() @@ -1982,6 +2352,8 @@ async def test_open_read_offset_size(self, sftp): f = await sftp.open('file') self.assertEqual((await f.read(4, 2)), 'xxyy') + self.assertEqual([result async for result in + await f.read_parallel(4, 2)], [(2, b'xxyy')]) finally: if f: # pragma: no branch await f.close() @@ -2015,7 +2387,24 @@ async def test_open_read_parallel(self, sftp): self._create_file('file', 40*1024*'\0') f = await sftp.open('file') - self.assertEqual(len((await f.read(64*1024))), 40*1024) + self.assertEqual(len(await f.read(64*1024)), 40*1024) + finally: + if f: # pragma: no branch + await f.close() + + remove('file') + + @sftp_test + async def test_open_read_max_requests(self, sftp): + """Test reading data from a file with max requests set""" + + f = None + + try: + self._create_file('file', 16*1024*1024*'\0') + + f = await sftp.open('file', max_requests=4) + self.assertEqual(len(await f.read()), 16*1024*1024) finally: if f: # pragma: no branch await f.close() @@ -2032,10 +2421,11 @@ async def _test_read_out_of_order(self, sftp): f = None try: - self._create_file('file', 4*1024*1024*'\0') + random_data = os.urandom(12*1024*1024) + self._create_file('file', random_data) - async with sftp.open('file') as f: - await f.read() + async with sftp.open('file', 'rb') as f: + self.assertEqual(await f.read(), random_data) finally: remove('file') @@ -2228,6 +2618,8 @@ async def test_open_append(self, sftp): f = await sftp.open('file', 'a+') await f.write('yyy') self.assertEqual((await f.read()), '') + self.assertEqual([result async for result in + await f.read_parallel()], []) await f.close() with open('file') as localf: @@ -2786,8 +3178,8 @@ async def test_file_utime_v4(self, sftp): self.assertEqual(stat_result.st_mtime_ns, 2250000000) self.assertEqual((await sftp.getatime('file')), 1.0) self.assertEqual((await sftp.getatime_ns('file')), 1000000000) - self.assertIsNotNone((await sftp.getcrtime('file'))) - self.assertIsNotNone((await sftp.getcrtime_ns('file'))) + self.assertIsNotNone(await sftp.getcrtime('file')) + self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 2.25) self.assertEqual((await sftp.getmtime_ns('file')), 2250000000) @@ -2800,8 +3192,8 @@ async def test_file_utime_v4(self, sftp): self.assertEqual(stat_result.st_mtime_ns, 4750000000) self.assertEqual((await sftp.getatime('file')), 3.5) self.assertEqual((await sftp.getatime_ns('file')), 3500000000) - self.assertIsNotNone((await sftp.getcrtime('file'))) - self.assertIsNotNone((await sftp.getcrtime_ns('file'))) + self.assertIsNotNone(await sftp.getcrtime('file')) + self.assertIsNotNone(await sftp.getcrtime_ns('file')) self.assertEqual((await sftp.getmtime('file')), 4.75) self.assertEqual((await sftp.getmtime_ns('file')), 4750000000) finally: @@ -2881,7 +3273,7 @@ async def test_file_sync(self, sftp): try: f = await sftp.open('file', 'w') - self.assertIsNone((await f.fsync())) + self.assertIsNone(await f.fsync()) finally: if f: # pragma: no branch await f.close() @@ -2959,6 +3351,9 @@ async def _return_invalid_handle(self, path, pflags, attrs): _return_invalid_handle): f = await sftp.open('file') + with self.assertRaises(SFTPFailure): + _ = [_ async for _ in f.request_ranges(0, 0)] + with self.assertRaises(SFTPFailure): await f.read() @@ -2981,6 +3376,9 @@ async def _return_invalid_handle(self, path, pflags, attrs): with self.assertRaises(SFTPFailure): await f.fsync() + with self.assertRaises(SFTPFailure): + await sftp.remote_copy(f, f) + with self.assertRaises(SFTPFailure): await f.close() @@ -3021,6 +3419,9 @@ async def test_closed_file(self, sftp): with self.assertRaises(ValueError): await f.read() + with self.assertRaises(ValueError): + await f.read_parallel() + with self.assertRaises(ValueError): await f.write('') @@ -3403,7 +3804,7 @@ async def _short_ok_response(self, pkttype, pktid, packet): with patch('asyncssh.sftp.SFTPServerHandler._process_packet', _short_ok_response): - self.assertIsNone((await sftp.mkdir('dir'))) + self.assertIsNone(await sftp.mkdir('dir')) @sftp_test async def test_malformed_realpath_response(self, sftp): @@ -3469,6 +3870,10 @@ async def _unsupported_extensions(self, sftp): with self.assertRaises(SFTPOpUnsupported): await f.fsync() + + with self.assertRaises(SFTPOpUnsupported): + await sftp.setstat('file1', SFTPAttrs(), + follow_symlinks=False) finally: if f: # pragma: no branch await f.close() @@ -3512,6 +3917,26 @@ async def _unsupported_extensions_v6(self, sftp): # pylint: disable=no-value-for-parameter _unsupported_extensions_v6(self) + @asynctest + async def test_zero_limits(self): + """Test sending a server limits response with zero read/write length""" + + async def _send_zero_read_write_len(self, packet): + """Send a server limits response with zero read/write length""" + + # pylint: disable=unused-argument + + return SFTPLimits(0, 0, 0, 0) + + with patch.dict('asyncssh.sftp.SFTPServerHandler._packet_handlers', + {b'limits@openssh.com': _send_zero_read_write_len}): + async with self.connect() as conn: + async with conn.start_sftp_client() as sftp: + self.assertEqual(sftp.limits.max_read_len, + SAFE_SFTP_READ_LEN) + self.assertEqual(sftp.limits.max_write_len, + SAFE_SFTP_WRITE_LEN) + def test_write_close(self): """Test session cleanup in the middle of a write request""" @@ -3759,6 +4184,35 @@ async def test_log_formatting(self, sftp): asyncssh.set_sftp_log_level('WARNING') + @sftp_test + async def test_makedirs_no_parent_perms(self, sftp): + """Test creating a directory path without perms for a parent dir""" + + orig_mkdir = sftp.mkdir + + def _mkdir(path, *args, **kwargs): + if path == b'/': + raise SFTPPermissionDenied('') + return orig_mkdir(path, *args, **kwargs) + + try: + root = os.path.abspath(os.getcwd()) + with patch.object(sftp, 'mkdir', _mkdir): + await sftp.makedirs(os.path.join(root, 'dir/dir1')) + self.assertTrue(os.path.isdir(os.path.join(root, 'dir/dir1'))) + finally: + remove('dir') + + @sftp_test + async def test_makedirs_no_perms(self, sftp): + """Test creating a directory path without perms for all parents""" + + root = os.path.abspath(os.getcwd()) + + with patch.object(sftp, 'mkdir', side_effect=SFTPPermissionDenied('')): + with self.assertRaises(SFTPPermissionDenied): + await sftp.makedirs(os.path.join(root, 'dir/dir1')) + class _TestSFTPCallable(_CheckSFTP): """Unit tests for AsyncSSH SFTP factory being a callable""" @@ -3767,10 +4221,33 @@ class _TestSFTPCallable(_CheckSFTP): async def start_server(cls): """Start an SFTP server using a callable""" - def sftp_factory(conn): + def sftp_factory(chan): """Return an SFTP server""" - return SFTPServer(conn) + return SFTPServer(chan) + + return await cls.create_server(sftp_factory=sftp_factory) + + @sftp_test + async def test_stat(self, sftp): + """Test getting attributes on a file""" + + # pylint: disable=no-self-use + + await sftp.stat('.') + + +class _TestSFTPCoroutine(_CheckSFTP): + """Unit tests for AsyncSSH SFTP factory being a coroutine""" + + @classmethod + async def start_server(cls): + """Start an SFTP server using a coroutine""" + + async def sftp_factory(chan): + """Return an SFTP server""" + + return SFTPServer(chan) return await cls.create_server(sftp_factory=sftp_factory) @@ -3830,7 +4307,7 @@ async def test_chroot_glob(self, sftp): try: self._create_file('chroot/file1') self._create_file('chroot/file2') - self.assertEqual(sorted((await sftp.glob('/file*'))), + self.assertEqual(sorted(await sftp.glob('/file*')), ['/file1', '/file2']) finally: remove('chroot/file1 chroot/file2') @@ -3877,6 +4354,7 @@ async def test_chroot_realpath_v6(self, sftp): with self.assertRaises(SFTPInvalidParameter): await sftp.realpath('.', check=99) + @sftp_test async def test_getcwd_and_chdir(self, sftp): """Test changing directory on an SFTP server with a changed root""" @@ -4051,30 +4529,39 @@ async def start_server(cls): return await cls.create_server(sftp_factory=_IOErrorSFTPServer) - @sftp_test - async def test_put_error(self, sftp): - """Test error when putting a file to an SFTP server""" + def test_copy_error(self): + """Test error when copying a file on an SFTP server""" - for method in ('get', 'put', 'copy'): - with self.subTest(method=method): - try: - self._create_file('src', 4*1024*1024*'\0') + @sftp_test + async def _test_copy_error(self, sftp): + """Test error when copying a file on an SFTP server""" - with self.assertRaises(SFTPFailure): - await getattr(sftp, method)('src', 'dst') - finally: - remove('src dst') + try: + self._create_file('src', 8*1024*1024*'\0') + + with self.assertRaises(SFTPFailure): + await sftp.copy('src', 'dst') + finally: + remove('src dst') + + with patch('asyncssh.sftp.SFTPServerHandler._extensions', []): + # pylint: disable=no-value-for-parameter + _test_copy_error(self) @sftp_test async def test_read_error(self, sftp): """Test error when reading a file on an SFTP server""" try: - self._create_file('file', 4*1024*1024*'\0') + self._create_file('file', 8*1024*1024*'\0') - with self.assertRaises(SFTPFailure): - async with sftp.open('file') as f: - await f.read(4*1024*1024) + async with sftp.open('file') as f: + with self.assertRaises(SFTPFailure): + await f.read(8*1024*1024) + + with self.assertRaises(SFTPFailure): + async for _ in await f.read_parallel(8*1024*1024): + pass finally: remove('file') @@ -4085,7 +4572,7 @@ async def test_write_error(self, sftp): try: with self.assertRaises(SFTPFailure): async with sftp.open('file', 'w') as f: - await f.write(4*1024*1024*'\0') + await f.write(8*1024*1024*'\0') finally: remove('file') @@ -4108,10 +4595,10 @@ async def test_read(self, sftp): data = os.urandom(65536) self._create_file('file', data) - async with sftp.open('file', 'rb') as f: - result = await f.read(32768, 16384) + async with sftp.open('file', 'rb', block_size=16384) as f: + result = await f.read(65536, 16384) - self.assertEqual(result, data[16384:49152]) + self.assertEqual(result, data[16384:]) finally: remove('file') @@ -4120,7 +4607,7 @@ async def test_get(self, sftp): """Test getting a file from an SFTP server with a small block size""" try: - data = os.urandom(65536) + data = os.urandom(8*1024*1024) self._create_file('src', data) await sftp.get('src', 'dst') self._check_file('src', 'dst') @@ -4142,10 +4629,10 @@ async def test_get(self, sftp): """Test getting a file from an SFTP server truncated during the copy""" try: - self._create_file('src', 65536*'\0') + self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): - await sftp.get('src', 'dst') + await sftp.get('src', 'dst', sparse=False) finally: remove('src dst') @@ -4262,7 +4749,7 @@ async def start_server(cls): async def test_large_listdir(self, sftp): """Test large listdir result""" - self.assertEqual(len((await sftp.readdir('/'))), 100000) + self.assertEqual(len(await sftp.readdir('/')), 100000) @unittest.skipIf(sys.platform == 'win32', 'skip statvfs tests on Windows') @@ -4298,7 +4785,7 @@ def _check_statvfs(self, sftp_statvfs): async def test_statvfs(self, sftp): """Test getting attributes on a filesystem""" - self._check_statvfs((await sftp.statvfs('.'))) + self._check_statvfs(await sftp.statvfs('.')) @sftp_test async def test_file_statvfs(self, sftp): @@ -4310,7 +4797,7 @@ async def test_file_statvfs(self, sftp): self._create_file('file') f = await sftp.open('file') - self._check_statvfs((await f.statvfs())) + self._check_statvfs(await f.statvfs()) finally: if f: # pragma: no branch await f.close() @@ -4548,6 +5035,25 @@ def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes): self.assertEqual(len(reports), (size // 8192) + 1) self.assertEqual(reports[-1], size) + async def _check_cancel(self, src, dst): + """Check cancelling a file transfer over SCP""" + + def _cancel(_srcpath, _dstpath, _bytes_copied, _total_bytes): + """Cancel transfer""" + + task.cancel() + + try: + self._create_file('src', 1024*8192 * 'a') + + coro = scp(src, dst, block_size=8192, progress_handler=_cancel) + task = asyncio.create_task(coro) + + with self.assertRaises(asyncio.CancelledError): + await task + finally: + remove('src dst') + class _TestSCP(_CheckSCP): """Unit tests for AsyncSSH SCP client and server""" @@ -4567,6 +5073,12 @@ async def test_get_progress(self): await self._check_progress((self._scp_server, 'src'), 'dst') + @asynctest + async def test_get_cancel(self): + """Test cancelling a get of a file over SCP""" + + await self._check_cancel((self._scp_server, 'src'), 'dst') + @asynctest async def test_get_preserve(self): """Test getting a file with preserved attributes over SCP""" @@ -4696,6 +5208,12 @@ async def test_put_progress(self): await self._check_progress('src', (self._scp_server, 'dst')) + @asynctest + async def test_put_cancel(self): + """Test cancelling a put of a file over SCP""" + + await self._check_cancel('src', (self._scp_server, 'dst')) + @asynctest async def test_put_preserve(self): """Test putting a file with preserved attributes over SCP""" @@ -4781,15 +5299,15 @@ async def test_put_read_error(self): """Test read errors when putting a file over SCP""" async def _read_error(self, size, offset): - """Return an error for reads past 64 KB in a file""" + """Return an error for reads past 4 MB in a file""" - if offset >= 65536: + if offset >= 4*1024*1024: raise OSError(errno.EIO, 'I/O error') else: return await orig_read(self, size, offset) try: - self._create_file('src', 128*1024*'\0') + self._create_file('src', 8*1024*1024*'\0') orig_read = LocalFile.read @@ -4804,15 +5322,15 @@ async def test_put_read_early_eof(self): """Test getting early EOF when putting a file over SCP""" async def _read_early_eof(self, size, offset): - """Return an early EOF for reads past 64 KB in a file""" + """Return an early EOF for reads past 4 MB in a file""" - if offset >= 65536: + if offset >= 4*1024*1024: return b'' else: return await orig_read(self, size, offset) try: - self._create_file('src', 128*1024*'\0') + self._create_file('src', 8*1024*1024*'\0') orig_read = LocalFile.read @@ -4851,6 +5369,13 @@ async def test_copy_progress(self): await self._check_progress((self._scp_server, 'src'), (self._scp_server, 'dst')) + @asynctest + async def test_copy_cancel(self): + """Test cancelling a copy of a file over SCP""" + + await self._check_cancel((self._scp_server, 'src'), + (self._scp_server, 'dst')) + @asynctest async def test_copy_preserve(self): """Test copying a file with preserved attributes between hosts""" @@ -4990,7 +5515,7 @@ async def test_source_bytes(self): """Test passing a byte string to SCP""" with self.assertRaises(OSError): - await scp('\xff:xxx'.encode('utf-8'), '.') + await scp('\xff:xxx'.encode(), '.') @asynctest async def test_source_open_connection(self): @@ -5072,6 +5597,22 @@ async def start_server(cls): allow_scp=True) +class _TestSCPCoroutine(_TestSCP): + """Unit test for AsyncSSH SCP with the SFTP factory being a coroutine""" + + @classmethod + async def start_server(cls): + """Start an SFTP server with async callbacks""" + + async def sftp_factory(chan): + """Return an SFTP server""" + + return SFTPServer(chan) + + return await cls.create_server(sftp_factory=sftp_factory, + allow_scp=True) + + class _TestSCPAttrs(_CheckSCP): """Unit test for SCP with SFTP server returning SFTPAttrs""" @@ -5088,7 +5629,7 @@ async def test_get(self): try: self._create_file('src') - await scp((self._scp_server, 'src'), 'dst') + await scp((self._scp_server, 'src*'), 'dst') self._check_file('src', 'dst') finally: remove('src dst') @@ -5138,7 +5679,7 @@ async def test_put_error(self): """Test error when putting a file over SCP""" try: - self._create_file('src', 4*1024*1024*'\0') + self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await scp('src', (self._scp_server, 'dst')) @@ -5150,7 +5691,7 @@ async def test_copy_error(self): """Test error when copying a file over SCP""" try: - self._create_file('src', 4*1024*1024*'\0') + self._create_file('src', 8*1024*1024*'\0') with self.assertRaises(SFTPFailure): await scp((self._scp_server, 'src'), diff --git a/tests/test_sk.py b/tests/test_sk.py index 94e94df..e8cc2f8 100644 --- a/tests/test_sk.py +++ b/tests/test_sk.py @@ -26,7 +26,7 @@ from .server import ServerTestCase from .sk_stub import sk_available, stub_sk, unstub_sk, patch_sk, sk_error -from .util import asynctest +from .util import asynctest, get_test_key class _CheckSKAuth(ServerTestCase): @@ -37,6 +37,7 @@ class _CheckSKAuth(ServerTestCase): _sk_resident = False _sk_touch_required = True _sk_auth_touch_required = True + _sk_use_webauthn = False _sk_cert = False _sk_host = False @@ -44,10 +45,11 @@ class _CheckSKAuth(ServerTestCase): async def start_server(cls): """Start an SSH server which supports security key authentication""" - cls.addClassCleanup(unstub_sk, *stub_sk(cls._sk_devs)) + cls.addClassCleanup(unstub_sk, *stub_sk(cls._sk_devs, + cls._sk_use_webauthn)) - cls._privkey = asyncssh.generate_private_key( - cls._sk_alg, resident=cls._sk_resident, + cls._privkey = get_test_key( + cls._sk_alg, cls._sk_use_webauthn, resident=cls._sk_resident, touch_required=cls._sk_touch_required) if cls._sk_host: @@ -123,8 +125,9 @@ async def test_auth_ctap1_error(self): """Test security key returning a CTAP 1 error""" with sk_error('err'): - with self.assertRaises(ValueError): - await self.connect(username='ckey', client_keys=[self._privkey]) + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(username='ckey', + client_keys=[self._privkey]) @unittest.skipUnless(sk_available, 'security key support not available') @@ -144,7 +147,7 @@ async def test_auth(self): async def test_enroll_without_pin(self): """Test generating key without a PIN""" - key = asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') + key = get_test_key('sk-ssh-ed25519@openssh.com') self.assertIsNotNone(key) @@ -152,8 +155,7 @@ async def test_enroll_without_pin(self): async def test_enroll_with_pin(self): """Test generating key with a PIN""" - key = asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com', - pin=b'123456') + key = get_test_key('sk-ssh-ed25519@openssh.com', pin=b'123456') self.assertIsNotNone(key) @@ -170,8 +172,9 @@ async def test_auth_ctap2_error(self): """Test security key returning a CTAP 2 error""" with sk_error('err'): - with self.assertRaises(ValueError): - await self.connect(username='ckey', client_keys=[self._privkey]) + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(username='ckey', + client_keys=[self._privkey]) @asynctest async def test_enroll_pin_invalid(self): @@ -191,6 +194,20 @@ async def test_enroll_pin_required(self): asyncssh.generate_private_key('sk-ssh-ed25519@openssh.com') +@unittest.skipUnless(sk_available, 'security key support not available') +class _TestSKAuthWebAuthN(_CheckSKAuth): + """Unit tests for security key authentication with WebAuthN""" + + _sk_alg = 'sk-ecdsa-sha2-nistp256@openssh.com' + _sk_use_webauthn = True + + @asynctest + async def test_auth(self): + """Test authenticating with the Windows WebAuthN API""" + + async with self.connect(username='ckey', client_keys=[self._privkey]): + pass + @unittest.skipUnless(sk_available, 'security key support not available') class _TestSKAuthMultipleKeys(_CheckSKAuth): """Unit tests for security key authentication with multiple keys""" @@ -202,8 +219,9 @@ async def test_auth_cred_not_found(self): """Test authenticating with security credential not found""" with sk_error('nocred'): - with self.assertRaises(ValueError): - await self.connect(username='ckey', client_keys=[self._privkey]) + with self.assertRaises(asyncssh.PermissionDenied): + await self.connect(username='ckey', + client_keys=[self._privkey]) @unittest.skipUnless(sk_available, 'security key support not available') @@ -256,7 +274,7 @@ async def test_load_resident_ctap2_error(self): """Test getting resident keys returning a CTAP 2 error""" with sk_error('err'): - with self.assertRaises(ValueError): + with self.assertRaises(asyncssh.KeyImportError): asyncssh.load_resident_keys(b'123456') @asynctest diff --git a/tests/test_stream.py b/tests/test_stream.py index fb1ebaa..83884d3 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -21,6 +21,7 @@ """Unit tests for AsyncSSH stream API""" import asyncio +import re import asyncssh @@ -48,11 +49,23 @@ async def _begin_session(self, stdin, stdout, stderr): await stdin.read(1) stdout.write('\n') elif action == 'disconnect': - stdout.write((await stdin.read(1))) + stdout.write(await stdin.read(1)) raise asyncssh.ConnectionLost('Connection lost') elif action == 'custom_disconnect': await stdin.read(1) raise asyncssh.DisconnectError(99, 'Disconnect') + elif action == 'partial': + try: + await stdin.readexactly(10) + except asyncio.IncompleteReadError as exc: + stdout.write(exc.partial) + + try: + await stdin.read() + except asyncssh.TerminalSizeChanged: + pass + + stdout.write(await stdin.readexactly(5)) else: stdin.channel.exit(255) @@ -82,6 +95,20 @@ def session_requested(self): return False +class _UpstreamForwardingServer(Server): + """Server for testing forwarding between SSH connections""" + + def __init__(self, upstream_conn): + super().__init__() + + self._upstream_conn = upstream_conn + + def session_requested(self): + """Handle a request to create a new session""" + + return self._upstream_conn + + class _TestStream(ServerTestCase): """Unit tests for AsyncSSH stream API""" @@ -116,9 +143,9 @@ async def _check_session(self, conn, large_block=False): self.assertEqual(data, stdout_data) self.assertEqual(data, stderr_data) - await stdin.channel.wait_closed() await stdin.drain() stdin.close() + await stdin.channel.wait_closed() @asynctest @@ -128,6 +155,24 @@ async def test_shell(self): async with self.connect() as conn: await self._check_session(conn) + @asynctest + async def test_upstream_shell(self): + """Test upstream forwarding of a shell request""" + + def upstream_server(): + """Return a server capable of forwarding between SSH connections""" + + return _UpstreamForwardingServer(upstream_conn) + + async with self.connect() as upstream_conn: + upstream_listener = await self.create_server(upstream_server) + upstream_port = upstream_listener.get_port() + + async with self.connect('127.0.0.1', upstream_port) as conn: + await self._check_session(conn) + + upstream_listener.close() + @asynctest async def test_shell_failure(self): """Test failure to start a shell""" @@ -233,6 +278,19 @@ async def test_readline_exception(self): with self.assertRaises(asyncssh.ConnectionLost): await stdout.readline() + @asynctest + async def test_readexactly_partial_exception(self): + """Test readexactly returning partial data before an exception""" + + async with self.connect() as conn: + stdin, stdout, _ = await conn.open_session('partial') + + stdin.write('abcde') + stdout.channel.change_terminal_size(80, 24) + stdin.write('fghij') + + self.assertEqual((await stdout.read()), 'abcdefghij') + @asynctest async def test_custom_disconnect(self): """Test receiving a custom disconnect message""" @@ -366,6 +424,27 @@ async def test_readuntil_empty_separator(self): stdin.close() + @asynctest + async def test_readuntil_regex(self): + """Test readuntil with a regex pattern""" + + async with self.connect() as conn: + stdin, stdout, _ = await conn.open_session() + stdin.write("hello world\nhello world") + output = await stdout.readuntil( + re.compile('hello world'), len('hello world') + ) + self.assertEqual(output, "hello world") + + output = await stdout.readuntil( + re.compile('hello world'), len('hello world') + ) + self.assertEqual(output, "\nhello world") + + stdin.close() + + await conn.wait_closed() + @asynctest async def test_abort(self): """Test abort on a channel""" diff --git a/tests/test_tuntap.py b/tests/test_tuntap.py new file mode 100644 index 0000000..5e90e9f --- /dev/null +++ b/tests/test_tuntap.py @@ -0,0 +1,757 @@ +# Copyright (c) 2024 by Ron Frederick and others. +# +# This program and the accompanying materials are made available under +# the terms of the Eclipse Public License v2.0 which accompanies this +# distribution and is available at: +# +# http://www.eclipse.org/legal/epl-2.0/ +# +# This program may also be made available under the following secondary +# licenses when the conditions for such availability set forth in the +# Eclipse Public License v2.0 are satisfied: +# +# GNU General Public License, Version 2.0, or any later versions of +# that license +# +# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later +# +# Contributors: +# Ron Frederick - initial implementation, API, and documentation + +"""Unit tests for AsyncSSH TUN/TAP support""" + +import asyncio +import builtins +import errno +import socket +import struct +import sys + +from unittest import skipIf, skipUnless +from unittest.mock import patch + +import asyncssh +from asyncssh.tuntap import IFF_FMT, LINUX_IFF_TUN + +from .server import Server, ServerTestCase +from .util import asynctest + +if sys.platform != 'win32': # pragma: no branch + import fcntl + + +_orig_funcs = {} + + +class _TunTapSocketMock: + """TunTap socket mock""" + + def ioctl(self, request, arg): + """Ignore ioctl requests to bring interface up""" + + # pylint: disable=no-self-use,unused-argument + + return arg + + def close(self): + """Close this mock""" + + # pylint: disable=no-self-use + + +class _TunTapMock: + """Common TUN/TAP mock""" + + _from_intf = {} + + def __init__(self, interface=None): + if interface in self._from_intf: + raise OSError(errno.EBUSY, 'Device busy') + + self._loop = asyncio.get_event_loop() + + self._sock1, self._sock2 = socket.socketpair(type=socket.SOCK_DGRAM) + self._sock2.setblocking(False) + + self._interface = interface + + if interface: + self._from_intf[interface] = self + + @classmethod + def lookup_intf(cls, interface): + """Look up mock by interface""" + + return cls._from_intf[interface] + + def fileno(self): + """Return the fileno of sock1""" + + return self._sock1.fileno() + + def setblocking(self, blocking): + """Set blocking mode on the socket""" + + self._sock1.setblocking(blocking) + + async def get_packets(self, count): + """Get packets written to the TUN/TAP""" + + return [await self._loop.sock_recv(self._sock2, 65536) + for _ in range(count)] + + def put_packets(self, packets): + """Put packets for the TUN/TAP to read""" + + for packet in packets: + self._sock2.send(packet) + + def read(self, size=-1): + """Read a packet""" + + return self._sock1.recv(size) + + def write(self, packet): + """Write a packet""" + + return self._sock1.send(packet) + + def close(self): + """Close this mock""" + + self._from_intf.pop(self._interface, None) + + self._sock2.send(b'') + self._sock1.close() + self._sock2.close() + + +class _TunTapOSXMock(_TunTapMock): + """TunTapOSX mock""" + + disable = False + + def __init__(self, name): + if self.disable: + raise OSError(errno.ENOENT, 'No such device') + + interface = name[5:] + + if int(interface[3:]) >= 16: + raise OSError(errno.ENOENT, 'No such device') + + super().__init__(interface) + + +class _DarwinUTunMock(_TunTapMock): + """Darwin UTun mock""" + + _AF_INET_PREFIX = socket.AF_INET.to_bytes(4, 'big') + + def __init__(self): + super().__init__() + + self._unit = None + + def ioctl(self, request, arg): + """Respond to DARWIN_CTLIOCGINFO request""" + + # pylint: disable=no-self-use,unused-argument + + return arg + + def connect(self, addr): + """Connect to requested unit""" + + _, unit = addr + + if unit == 0: + for unit in range(16): + interface = f'utun{unit}' + + if interface not in self._from_intf: + break + else: + raise OSError(errno.EBUSY, 'No utun devices available') + elif unit <= 16: + unit -= 1 + interface = f'utun{unit}' + + if interface in self._from_intf: + raise OSError(errno.EBUSY, 'Device busy') + else: + raise OSError(errno.ENOENT, 'No such device') + + self._unit = unit + self._interface = interface + self._from_intf[interface] = self + + return 0 + + def getpeername(self): + """Return utun unit""" + + return (0, self._unit + 1) + + def send(self, packet): + """Send a packet""" + + return super().write(packet[4:]) + + def recv(self, size): + """Receive a packet""" + + return self._AF_INET_PREFIX + self.read(size) + + +class _LinuxMock(_TunTapMock): + """Linux TUN/TAP mock""" + + def __init__(self): + super().__init__() + + self._sock1.setblocking(False) + + def ioctl(self, request, arg): + """Respond to LINUX_TUNSETIFF request""" + + # pylint: disable=unused-argument + + name, flags = struct.unpack(IFF_FMT, arg) + + if name[0] == 0: + prefix = 'tun' if flags & LINUX_IFF_TUN else 'tap' + + for unit in range(16): + interface = f'{prefix}{unit}' + + if interface not in self._from_intf: + break + else: + self.close() + raise OSError(errno.EBUSY, 'No tun devices available') + + arg = struct.pack(IFF_FMT, interface.encode(), flags) + else: + interface = name.strip(b'\0').decode() + unit = int(interface[3:]) + + if unit >= 16: + raise OSError(errno.ENOENT, 'No such device') + + self._interface = interface + self._from_intf[interface] = self + + return arg + + def read(self, size=-1): + """Read a packet""" + + try: + return super().read(size) + except BlockingIOError: + return None + + +def _open(name, mode, *args, **kwargs): + """Mock file open""" + + name = str(name) + + if name.startswith('/dev/tun') or name.startswith('/dev/tap'): + return _TunTapOSXMock(name) + elif name == '/dev/net/tun': + return _LinuxMock() + else: + return _orig_funcs['open'](name, mode, *args, **kwargs) + + +# pylint: disable=redefined-builtin +def _socket(family=socket.AF_INET, type=socket.SOCK_STREAM, + proto=0, fileno=None): + """Mock socket creation""" + + if hasattr(socket, 'PF_SYSTEM') and family == socket.PF_SYSTEM and \ + type == socket.SOCK_DGRAM and proto == socket.SYSPROTO_CONTROL: + return _DarwinUTunMock() + elif family == socket.AF_INET and type == socket.SOCK_DGRAM: + return _TunTapSocketMock() + else: + return _orig_funcs['socket'](family, type, proto, fileno) + + +def _ioctl(file, request, arg): + """Mock ioctl""" + + if isinstance(file, (_DarwinUTunMock, _LinuxMock, _TunTapSocketMock)): + return file.ioctl(request, arg) + else: # pragma: no cover + return _orig_funcs['ioctl'](file, request, arg) + + +async def get_packets(interface, count): + """Return TUN/TAP packets written""" + + return await _TunTapMock.lookup_intf(interface).get_packets(count) + + +def put_packets(interface, packets): + """Feed packets to a TUN/TAP mock""" + + _TunTapMock.lookup_intf(interface).put_packets(packets) + + +def patch_tuntap(cls): + """Decorator to stub out TUN/TAP functions""" + + _orig_funcs['open'] = builtins.open + _orig_funcs['socket'] = socket.socket + + cls = patch('builtins.open', _open)(cls) + cls = patch('socket.socket', _socket)(cls) + + if sys.platform != 'win32': # pragma: no branch + _orig_funcs['ioctl'] = fcntl.ioctl + cls = patch('fcntl.ioctl', _ioctl)(cls) + + return cls + + +class _EchoSession(asyncssh.SSHTunTapSession): + """Echo packets on a TUN session""" + + def __init__(self): + self._chan = None + + def connection_made(self, chan): + """Handle session open""" + + self._chan = chan + + def data_received(self, data, datatype): + """Handle data from the channel""" + + self._chan.write(data) + + def eof_received(self): + """Handle EOF from the channel""" + + self._chan.write_eof() + + +class _TunTapServer(Server): + """Server for testing TUN/TAP functions""" + + async def _echo_handler(self, reader, writer): + """Echo packets on a TUN session""" + + try: + async for packet in reader: + writer.write(packet) + finally: + writer.close() + + def tun_requested(self, unit): + """Handle TUN requests""" + + if unit is None or unit <= 32: + return True + elif unit == 33: + return _EchoSession() + elif unit == 34: + return (self._conn.create_tuntap_channel(), _EchoSession()) + elif unit == 35: + return self._echo_handler + else: + return False + + def tap_requested(self, unit): + """Handle TAP requests""" + + if unit == 33: + return _EchoSession() + else: + return True + + +class _UpstreamForwardingServer(Server): + """Server for testing forwarding between SSH connections""" + + def __init__(self, upstream_conn): + super().__init__() + + self._upstream_conn = upstream_conn + + def tun_requested(self, unit): + """Handle a request to create a new layer 3 tunnel""" + + return self._upstream_conn + + def tap_requested(self, unit): + """Handle a request to create a new layer 2 tunnel""" + + return self._upstream_conn + + +@skipIf(sys.platform == 'win32', 'skip TUN/TAP tests on Windows') +@patch_tuntap +class _TestTunTap(ServerTestCase): + """Unit tests for TUN/TAP functions""" + + @classmethod + async def start_server(cls): + """Start an SSH server to connect to""" + + return await cls.create_server( + _TunTapServer, authorized_client_keys='authorized_keys') + + async def _check_tuntap(self, coro, interface): + """Check sending data on a TUN or TAP channel""" + + reader, writer = await coro + + try: + packets = [b'123', b'456', b'789'] + count = len(packets) + + for packet in packets: + writer.write(packet) + + self.assertEqual((await get_packets(interface, count)), packets) + + put_packets(interface, packets) + + for packet in packets: + self.assertEqual((await reader.read()), packet) + finally: + writer.close() + + async def _check_tuntap_forward(self, coro, remote_interface): + """Check sending data on a TUN or TAP channel""" + + async with coro as forw: + local_interface = forw.get_extra_info('interface') + + packets = [b'123', b'456', b'789'] + count = len(packets) + + put_packets(local_interface, packets) + + self.assertEqual((await get_packets(remote_interface, count)), + packets) + + put_packets(remote_interface, packets) + + self.assertEqual((await get_packets(local_interface, count)), + packets) + + async def _check_tuntap_echo(self, coro): + """Check echoing of packets on a TUN channel""" + + reader, writer = await coro + + try: + writer.write(b'123') + self.assertEqual((await reader.read()), b'123') + writer.write_eof() + self.assertEqual((await reader.read()), b'') + finally: + writer.close() + await writer.wait_closed() + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tun(self): + """Test sending packets on a layer 3 tunnel on macOS""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(), 'tun0') + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tun_specific_unit(self): + """Test sending on a layer 3 tunnel with specific unit on macOS""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(0), 'tun0') + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tun_error(self): + """Test returning an open error on a layer 3 tunnel on macOS""" + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.open_tun(32) + + @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') + @asynctest + async def test_darwin_open_utun(self): + """Test sending packets on a layer 3 tunnel using UTun on macOS""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(16), 'utun0') + + @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') + @asynctest + async def test_darwin_failover_to_utun(self): + """Test failing over from TunTapOSX to UTun on macOS""" + + try: + _TunTapOSXMock.disable = True + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(), 'utun0') + finally: + _TunTapOSXMock.disable = False + + @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') + @asynctest + async def test_darwin_utun_in_use(self): + """Test UTun device already in use on macOS""" + + async with self.connect() as conn: + _, writer = await conn.open_tun(16) + + try: + with self.assertRaises(asyncssh.ChannelOpenError): + await conn.open_tun(16) + finally: + writer.close() + await writer.wait_closed() + + @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') + @asynctest + async def test_darwin_utun_all_in_use(self): + """Test all UTun devices already in use on macOS""" + + async with self.connect() as conn: + writers = [] + + try: + for unit in range(32): + _, writer = await conn.open_tun(unit) + writers.append(writer) + + with self.assertRaises(asyncssh.ChannelOpenError): + await conn.open_tun() + finally: + for writer in writers: + writer.close() + await writer.wait_closed() + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tap(self): + """Test sending packets on a layer 2 tunnel on macOS""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tap(), 'tap0') + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tap_unavailable(self): + """Test TunTapOSX not being available on macOS""" + + try: + _TunTapOSXMock.disable = True + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.open_tap() + finally: + _TunTapOSXMock.disable = False + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_open_tap_error(self): + """Test sending packets on a layer 2 tunnel on macOS""" + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.open_tap(16) + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_forward_tun(self): + """Test forwarding packets on a layer 3 tunnel on macOS""" + + async with self.connect() as conn: + await self._check_tuntap_forward(conn.forward_tun(), 'tun0') + + @skipUnless(sys.platform == 'darwin', 'only run utun tests on macOS') + @asynctest + async def test_darwin_forward_utun(self): + """Test forwarding packets on a layer 3 tunnel on macOS""" + + async with self.connect() as conn: + await self._check_tuntap_forward(conn.forward_tun(16, 17), 'utun1') + + @skipUnless(sys.platform == 'darwin', 'only run TapTunOSX tests on macOS') + @asynctest + async def test_darwin_forward_tap(self): + """Test forwarding packets on a layer 2 tunnel on macOS""" + + async with self.connect() as conn: + await self._check_tuntap_forward(conn.forward_tap(), 'tap0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_open_tun(self): + """Test sending packets on a layer 3 tunnel on Linux""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(), 'tun0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_open_tun_specific_unit(self): + """Test sending on a layer 3 tunnel with specific unit on Linux""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tun(), 'tun0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_open_tun_error(self): + """Test returning an open error on a layer 3 tunnel on Linux""" + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.open_tun(32) + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_open_tap(self): + """Test sending packets on a layer 2 tunnel on Linux""" + + async with self.connect() as conn: + await self._check_tuntap(conn.open_tap(), 'tap0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_forward_tun(self): + """Test forwarding packets on a layer 3 tunnel on Linux""" + + async with self.connect() as conn: + await self._check_tuntap_forward(conn.forward_tun(), 'tun0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_forward_tap(self): + """Test forwarding packets on a layer 2 tunnel on Linux""" + + async with self.connect() as conn: + await self._check_tuntap_forward(conn.forward_tap(), 'tap0') + + @patch('sys.platform', 'linux') + @asynctest + async def test_linux_all_in_use(self): + """Test all TUN devices already in use on Linux""" + + async with self.connect() as conn: + writers = [] + + try: + for unit in range(16): + _, writer = await conn.open_tun(unit) + writers.append(writer) + + with self.assertRaises(asyncssh.ChannelOpenError): + await conn.open_tun() + finally: + for writer in writers: + writer.close() + await writer.wait_closed() + + @patch('sys.platform', 'xxx') + @asynctest + async def test_unknown_platform(self): + """Test unknown platform""" + + async with self.connect() as conn: + with self.assertRaises(asyncssh.ChannelOpenError): + await conn.open_tun() + + @asynctest + async def test_open_tun_echo_session(self): + """Test an echo session on a layer 3 tunnel""" + + async with self.connect() as conn: + await self._check_tuntap_echo(conn.open_tun(33)) + + @asynctest + async def test_upstream_open_tun_echo_session(self): + """Test an echo session on a forwarded layer 3 tunnel""" + + def upstream_server(): + """Return a server capable of forwarding between SSH connections""" + + return _UpstreamForwardingServer(upstream_conn) + + async with self.connect() as upstream_conn: + upstream_listener = await self.create_server(upstream_server) + upstream_port = upstream_listener.get_port() + + async with self.connect('127.0.0.1', upstream_port) as conn: + await self._check_tuntap_echo(conn.open_tun(33)) + + upstream_listener.close() + + @asynctest + async def test_upstream_open_tap_echo_session(self): + """Test an echo session on a forwarded layer 2 tunnel""" + + def upstream_server(): + """Return a server capable of forwarding between SSH connections""" + + return _UpstreamForwardingServer(upstream_conn) + + async with self.connect() as upstream_conn: + upstream_listener = await self.create_server(upstream_server) + upstream_port = upstream_listener.get_port() + + async with self.connect('127.0.0.1', upstream_port) as conn: + await self._check_tuntap_echo(conn.open_tap(33)) + + upstream_listener.close() + + @asynctest + async def test_open_tun_echo_session_channel(self): + """Test an echo session & channel on a layer 3 tunnel""" + + async with self.connect() as conn: + await self._check_tuntap_echo(conn.open_tun(34)) + + @asynctest + async def test_open_tun_echo_handler(self): + """Test an echo stream handler on a layer 3 tunnel""" + + async with self.connect() as conn: + await self._check_tuntap_echo(conn.open_tun(35)) + + @asynctest + async def test_open_tun_denied(self): + """Test returning an open error on a layer 3 tunnel""" + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.open_tun(36) + + @asynctest + async def test_tun_forward_error(self): + """Test returning a forward error on a layer 3 tunnel""" + + with self.assertRaises(asyncssh.ChannelOpenError): + async with self.connect() as conn: + await conn.forward_tun(36) + + @asynctest + async def test_invalid_tun_mode(self): + """Test sending an invalid mode in a TUN/TAP request""" + + async with self.connect() as conn: + chan = conn.create_tuntap_channel() + + with self.assertRaises(asyncssh.ChannelOpenError): + await chan.open(asyncssh.SSHTunTapSession, 32, 0) diff --git a/tests/test_x509.py b/tests/test_x509.py index 65b9b26..d923dd6 100644 --- a/tests/test_x509.py +++ b/tests/test_x509.py @@ -25,9 +25,7 @@ from cryptography import x509 -import asyncssh - -from .util import x509_available +from .util import get_test_key, x509_available if x509_available: # pragma: no branch from asyncssh.crypto import X509Name, X509NamePattern @@ -43,7 +41,7 @@ class _TestX509(unittest.TestCase): @classmethod def setUpClass(cls): - cls._privkey = asyncssh.generate_private_key('ssh-rsa') + cls._privkey = get_test_key('ssh-rsa') cls._pubkey = cls._privkey.convert_to_public() cls._pubdata = cls._pubkey.export_public_key('pkcs8-der') @@ -81,7 +79,7 @@ def test_generate(self): cert = self.generate_certificate(purposes='secureShellClient') - self.assertEqual(cert.purposes, set((_purpose_secureShellClient,))) + self.assertEqual(cert.purposes, {_purpose_secureShellClient}) def test_generate_ca(self): """Test X.509 CA certificate generation""" diff --git a/tests/util.py b/tests/util.py index b947fe5..4b9a2f9 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by Ron Frederick and others. +# Copyright (c) 2015-2022 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -24,6 +24,8 @@ import binascii import functools import os +import shutil +import socket import subprocess import sys import tempfile @@ -31,7 +33,15 @@ from unittest.mock import patch -# pylint: disable=unused-import +from asyncssh import set_default_skip_rsa_key_validation +from asyncssh.gss import gss_available +from asyncssh.logging import logger +from asyncssh.misc import ConnectionLost, SignalReceived +from asyncssh.packet import Byte, String, UInt32, UInt64 +from asyncssh.public_key import generate_private_key + + +# pylint: disable=ungrouped-imports, unused-import try: import bcrypt @@ -39,6 +49,8 @@ except ImportError: # pragma: no cover bcrypt_available = False +nc_available = bool(shutil.which('nc')) + try: import uvloop uvloop_available = True @@ -51,26 +63,25 @@ except ImportError: # pragma: no cover x509_available = False -# pylint: enable=unused-import - -from asyncssh.gss import gss_available -from asyncssh.logging import logger -from asyncssh.misc import ConnectionLost, SignalReceived -from asyncssh.packet import Byte, String, UInt32, UInt64 - +# pylint: enable=ungrouped-imports, unused-import # pylint: disable=no-member if hasattr(asyncio, 'all_tasks'): all_tasks = asyncio.all_tasks current_task = asyncio.current_task -else: +else: # pragma: no cover all_tasks = asyncio.Task.all_tasks current_task = asyncio.Task.current_task # pylint: enable=no-member +_test_keys = {} + +set_default_skip_rsa_key_validation(True) + + def asynctest(coro): """Decorator for async tests, for use with AsyncTestCase""" @@ -83,6 +94,31 @@ def async_wrapper(self, *args, **kwargs): return async_wrapper +def patch_getaddrinfo(cls): + """Decorator for patching socket.getaddrinfo""" + + # pylint: disable=redefined-builtin + + cls.orig_getaddrinfo = socket.getaddrinfo + + hosts = {'testhost.test': '', + 'testcname.test': 'cname.test', + 'cname.test': ''} + + def getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + """Mock DNS lookup of server hostname""" + + # pylint: disable=unused-argument + + try: + return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, + hosts[host], ('127.0.0.1', port))] + except KeyError: + return cls.orig_getaddrinfo(host, port, family, type, proto, flags) + + return patch('socket.getaddrinfo', getaddrinfo)(cls) + + def patch_getnameinfo(cls): """Decorator for patching socket.getnameinfo""" @@ -96,6 +132,33 @@ def getnameinfo(sockaddr, flags): return patch('socket.getnameinfo', getnameinfo)(cls) +def patch_getnameinfo_error(cls): + """Decorator for patching socket.getnameinfo to raise an error""" + + def getnameinfo_error(sockaddr, flags): + """Mock failure of reverse DNS lookup of client address""" + + # pylint: disable=unused-argument + + raise socket.gaierror() + + return patch('socket.getnameinfo', getnameinfo_error)(cls) + + +def patch_extra_kex(cls): + """Decorator for skipping extra kex algs""" + + def skip_extra_kex_algs(self): + """Don't send extra key exchange algorithms""" + + # pylint: disable=unused-argument + + return [] + + return patch('asyncssh.connection.SSHConnection._get_extra_kex_algs', + skip_extra_kex_algs)(cls) + + def patch_gss(cls): """Decorator for patching GSSAPI classes""" @@ -157,6 +220,20 @@ def _encode_options(options): return b''.join((String(k) + String(v) for k, v in options.items())) +def get_test_key(alg_name, key_id=0, **kwargs): + """Generate or return a key with the requested parameters""" + + params = tuple((alg_name, key_id)) + tuple(kwargs.items()) + + try: + key = _test_keys[params] + except KeyError: + key = generate_private_key(alg_name, **kwargs) + _test_keys[params] = key + + return key + + def make_certificate(cert_version, cert_type, key, signing_key, principals, key_id='name', valid_after=0, valid_before=0xffffffffffffffff, options=None, @@ -164,7 +241,7 @@ def make_certificate(cert_version, cert_type, key, signing_key, principals, """Construct an SSH certificate""" keydata = key.encode_ssh_public() - principals = b''.join((String(p) for p in principals)) + principals = b''.join(String(p) for p in principals) options = _encode_options(options) if options else b'' extensions = _encode_options(extensions) if extensions else b'' signing_keydata = b''.join((String(signing_key.algorithm), @@ -179,7 +256,7 @@ def make_certificate(cert_version, cert_type, key, signing_key, principals, if bad_signature: data += String('') else: - data += String(signing_key.sign(data, signing_key.algorithm)) + data += String(signing_key.sign(data, signing_key.sig_algorithms[0])) return b''.join((cert_version.encode('ascii'), b' ', binascii.b2a_base64(data))) @@ -197,6 +274,15 @@ def run(cmd): raise +def try_remove(filename): + """Try to remove a file, ignoring errors""" + + try: + os.remove(filename) + except OSError: # pragma: no cover + pass + + class ConnectionStub: """Stub class used to replace an SSHConnection object""" @@ -262,7 +348,7 @@ async def _process_packets(self): self.connection_lost(data) break - self.process_packet(data) + await self.process_packet(data) def connection_lost(self, exc): """Handle the closing of a connection""" @@ -362,8 +448,6 @@ def setUpClass(cls): else: cls.loop = asyncio.new_event_loop() - asyncio.set_event_loop(cls.loop) - try: cls.loop.run_until_complete(cls.asyncSetUpClass()) except AttributeError: diff --git a/tox.ini b/tox.ini index e7c4926..28e7d7e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,10 @@ [tox] -minversion = 3.7 -envlist = clean,{py36,py37,py38,py39,py310}-{linux,darwin,windows},report +minversion = 3.8 skip_missing_interpreters = True +envlist = + clean + report + py3{8,9,10,11,12,13}-{linux,darwin,windows} [testenv] deps = @@ -23,27 +26,34 @@ platform = windows: win32 usedevelop = True setenv = - {py36,py37,py38,py39,py310}-{linux,darwin,windows}: COVERAGE_FILE = .coverage.{envname} + PIP_USE_PEP517 = 1 + COVERAGE_FILE = .coverage.{envname} commands = {envpython} -m pytest --cov --cov-report=term-missing:skip-covered {posargs} depends = - {py36,py37,py38,py39,py310}-{linux,darwin,windows}: clean - report: {py36,py37,py38,py39,py310}-{linux,darwin,windows} + clean [testenv:clean] deps = coverage skip_install = true +setenv = + COVERAGE_FILE = commands = coverage erase +depends = [testenv:report] deps = coverage skip_install = true parallel_show_output = true +setenv = + COVERAGE_FILE = commands = coverage combine coverage report --show-missing coverage html coverage xml +depends = + py3{8,9,10,11,12,13}-{linux,darwin,windows} [pytest] testpaths = tests