Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions dvc/scm/git/backend/dulwich/asyncssh_vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,56 @@ async def _read(self, n: Optional[int] = None) -> bytes:

read = sync_wrapper(_read)

def write(self, data: bytes):
async def _write(self, data: bytes):
self.proc.stdin.write(data)
await self.proc.stdin.drain()

def close(self):
write = sync_wrapper(_write)

async def _close(self):
self.conn.close()
await self.conn.wait_closed()

close = sync_wrapper(_close)


# NOTE: Github's SSH server does not strictly comply with the SSH protocol.
# When validating a public key using the rsa-sha2-256 or rsa-sha2-512
# signature algorithms, RFC4252 + RFC8332 state that the server should respond
# with the same algorithm in SSH_MSG_USERAUTH_PK_OK. Github's server always
# returns "ssh-rsa" rather than the correct sha2 algorithm name (likely for
# backwards compatibility with old SSH client reasons). This behavior causes
# asyncssh to fail with a key-mismatch error (since asyncssh expects the server
# to behave properly).
#
# See also:
# https://www.ietf.org/rfc/rfc4252.txt
# https://www.ietf.org/rfc/rfc8332.txt
Comment on lines +57 to +68
Copy link
Contributor Author

@pmrowla pmrowla Nov 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should note that this appears to be specific to github. gitlab's SSH server returns the proper algorithm names in the auth responses.

This is also one of those cases where the RFC's do note that some existing SSH implementations in the wild may just always use "ssh-rsa", and that it's up to individual client/server implementations to decide whether or not to allow or reject those requests & responses.

def _process_public_key_ok_gh(self, _pkttype, _pktid, packet):
from asyncssh.misc import ProtocolError

algorithm = packet.get_string()
key_data = packet.get_string()
packet.check_end()

# pylint: disable=protected-access
if (
(
algorithm == b"ssh-rsa"
and self._keypair.algorithm
not in (
b"ssh-rsa",
b"rsa-sha2-256",
b"rsa-sha2-512",
)
)
or (algorithm != b"ssh-rsa" and algorithm != self._keypair.algorithm)
or key_data != self._keypair.public_data
):
raise ProtocolError("Key mismatch")

self.create_task(self._send_signed_request())
return True


class AsyncSSHVendor(BaseAsyncObject, SSHVendor):
Expand Down Expand Up @@ -76,13 +121,20 @@ async def _run_command(
key_filename: Optional path to private keyfile
"""
import asyncssh
from asyncssh.auth import MSG_USERAUTH_PK_OK, _ClientPublicKeyAuth

# pylint: disable=protected-access
_ClientPublicKeyAuth._packet_handlers[
MSG_USERAUTH_PK_OK
] = _process_public_key_ok_gh

conn = await asyncssh.connect(
host,
port=port,
port=port if port is not None else (),
username=username,
password=password,
client_keys=[key_filename] if key_filename else [],
client_keys=[key_filename] if key_filename else (),
ignore_encrypted=not key_filename,
known_hosts=None,
encoding=None,
)
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/scm/test_git.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,29 @@ def test_pygit_checkout_subdir(tmp_dir, scm, git):
with (tmp_dir / "dir").chdir():
git.checkout(rev)
assert not (tmp_dir / "dir" / "bar").exists()


@pytest.mark.parametrize(
"algorithm", [b"ssh-rsa", b"rsa-sha2-256", b"rsa-sha2-512"]
)
def test_dulwich_github_compat(mocker, algorithm):
from asyncssh.misc import ProtocolError

from dvc.scm.git.backend.dulwich.asyncssh_vendor import (
_process_public_key_ok_gh,
)

key_data = b"foo"
auth = mocker.Mock(
_keypair=mocker.Mock(algorithm=algorithm, public_data=key_data),
)
packet = mocker.Mock()

with pytest.raises(ProtocolError):
strings = iter((b"ed21556", key_data))
packet.get_string = lambda: next(strings)
_process_public_key_ok_gh(auth, None, None, packet)

strings = iter((b"ssh-rsa", key_data))
packet.get_string = lambda: next(strings)
_process_public_key_ok_gh(auth, None, None, packet)