diff --git a/dvc/scm/git/backend/dulwich/asyncssh_vendor.py b/dvc/scm/git/backend/dulwich/asyncssh_vendor.py index 3264c5c398..2e42f0f42f 100644 --- a/dvc/scm/git/backend/dulwich/asyncssh_vendor.py +++ b/dvc/scm/git/backend/dulwich/asyncssh_vendor.py @@ -41,11 +41,56 @@ async def _read(self, n: Optional[int] = None) -> bytes: read = sync_wrapper(_read) - def write(self, data: bytes): + async def _write(self, data: bytes): self.proc.stdin.write(data) + await self.proc.stdin.drain() - def close(self): + write = sync_wrapper(_write) + + async def _close(self): self.conn.close() + await self.conn.wait_closed() + + close = sync_wrapper(_close) + + +# NOTE: Github's SSH server does not strictly comply with the SSH protocol. +# When validating a public key using the rsa-sha2-256 or rsa-sha2-512 +# signature algorithms, RFC4252 + RFC8332 state that the server should respond +# with the same algorithm in SSH_MSG_USERAUTH_PK_OK. Github's server always +# returns "ssh-rsa" rather than the correct sha2 algorithm name (likely for +# backwards compatibility with old SSH client reasons). This behavior causes +# asyncssh to fail with a key-mismatch error (since asyncssh expects the server +# to behave properly). +# +# See also: +# https://www.ietf.org/rfc/rfc4252.txt +# https://www.ietf.org/rfc/rfc8332.txt +def _process_public_key_ok_gh(self, _pkttype, _pktid, packet): + from asyncssh.misc import ProtocolError + + algorithm = packet.get_string() + key_data = packet.get_string() + packet.check_end() + + # pylint: disable=protected-access + if ( + ( + algorithm == b"ssh-rsa" + and self._keypair.algorithm + not in ( + b"ssh-rsa", + b"rsa-sha2-256", + b"rsa-sha2-512", + ) + ) + or (algorithm != b"ssh-rsa" and algorithm != self._keypair.algorithm) + or key_data != self._keypair.public_data + ): + raise ProtocolError("Key mismatch") + + self.create_task(self._send_signed_request()) + return True class AsyncSSHVendor(BaseAsyncObject, SSHVendor): @@ -76,13 +121,20 @@ async def _run_command( key_filename: Optional path to private keyfile """ import asyncssh + from asyncssh.auth import MSG_USERAUTH_PK_OK, _ClientPublicKeyAuth + + # pylint: disable=protected-access + _ClientPublicKeyAuth._packet_handlers[ + MSG_USERAUTH_PK_OK + ] = _process_public_key_ok_gh conn = await asyncssh.connect( host, - port=port, + port=port if port is not None else (), username=username, password=password, - client_keys=[key_filename] if key_filename else [], + client_keys=[key_filename] if key_filename else (), + ignore_encrypted=not key_filename, known_hosts=None, encoding=None, ) diff --git a/tests/unit/scm/test_git.py b/tests/unit/scm/test_git.py index 3b86bd527c..085bc8572a 100644 --- a/tests/unit/scm/test_git.py +++ b/tests/unit/scm/test_git.py @@ -594,3 +594,29 @@ def test_pygit_checkout_subdir(tmp_dir, scm, git): with (tmp_dir / "dir").chdir(): git.checkout(rev) assert not (tmp_dir / "dir" / "bar").exists() + + +@pytest.mark.parametrize( + "algorithm", [b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"] +) +def test_dulwich_github_compat(mocker, algorithm): + from asyncssh.misc import ProtocolError + + from dvc.scm.git.backend.dulwich.asyncssh_vendor import ( + _process_public_key_ok_gh, + ) + + key_data = b"foo" + auth = mocker.Mock( + _keypair=mocker.Mock(algorithm=algorithm, public_data=key_data), + ) + packet = mocker.Mock() + + with pytest.raises(ProtocolError): + strings = iter((b"ed21556", key_data)) + packet.get_string = lambda: next(strings) + _process_public_key_ok_gh(auth, None, None, packet) + + strings = iter((b"ssh-rsa", key_data)) + packet.get_string = lambda: next(strings) + _process_public_key_ok_gh(auth, None, None, packet)