diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index 9a90065b44..0ef060efcf 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os import getpass import logging @@ -38,18 +39,26 @@ def __init__(self, repo, config): parsed = urlparse(self.url) self.host = parsed.hostname + + user_ssh_config = self._load_user_ssh_config(self.host) + + self.host = user_ssh_config.get("hostname", self.host) self.user = ( config.get(Config.SECTION_REMOTE_USER) or parsed.username + or user_ssh_config.get("user") or getpass.getuser() ) self.prefix = parsed.path or "/" self.port = ( config.get(Config.SECTION_REMOTE_PORT) or parsed.port + or self._try_get_ssh_config_port(user_ssh_config) or self.DEFAULT_PORT ) - self.keyfile = config.get(Config.SECTION_REMOTE_KEY_FILE, None) + self.keyfile = config.get( + Config.SECTION_REMOTE_KEY_FILE + ) or self._try_get_ssh_config_keyfile(user_ssh_config) self.timeout = config.get(Config.SECTION_REMOTE_TIMEOUT, self.TIMEOUT) self.password = config.get(Config.SECTION_REMOTE_PASSWORD, None) self.ask_password = config.get( @@ -63,6 +72,35 @@ def __init__(self, repo, config): "port": self.port, } + @staticmethod + def ssh_config_filename(): + return os.path.expanduser(os.path.join("~", ".ssh", "config")) + + @staticmethod + def _load_user_ssh_config(hostname): + user_config_file = RemoteSSH.ssh_config_filename() + user_ssh_config = dict() + if hostname and os.path.exists(user_config_file): + with open(user_config_file) as f: + ssh_config = paramiko.SSHConfig() + ssh_config.parse(f) + user_ssh_config = ssh_config.lookup(hostname) + return user_ssh_config + + @staticmethod + def _try_get_ssh_config_port(user_ssh_config): + try: + return int(user_ssh_config.get("port")) + except (ValueError, TypeError): + return None + + @staticmethod + def _try_get_ssh_config_keyfile(user_ssh_config): + identity_file = user_ssh_config.get("identityfile") + if identity_file and len(identity_file) > 0: + return identity_file[0] + return None + def ssh(self, host=None, user=None, port=None, **kwargs): logger.debug( "Establishing ssh connection with '{host}' " diff --git a/tests/requirements.txt b/tests/requirements.txt index fc290e9450..c7e1627afe 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -4,7 +4,7 @@ pytest-cov>=2.6.1 pytest-xdist>=1.26.1 pytest-mock>=1.10.4 flaky>=3.5.3 -mock>=2.0.0 +mock>=3.0.0 xmltodict>=0.11.0 awscli>=1.16.125 google-compute-engine diff --git a/tests/unit/remote/ssh/test_ssh.py b/tests/unit/remote/ssh/test_ssh.py index 9afd9fe7df..cc66b3adbf 100644 --- a/tests/unit/remote/ssh/test_ssh.py +++ b/tests/unit/remote/ssh/test_ssh.py @@ -1,28 +1,16 @@ import getpass +import os +import sys from unittest import TestCase +from mock import patch, mock_open +import pytest + from dvc.remote.ssh import RemoteSSH class TestRemoteSSH(TestCase): - def test_user(self): - url = "ssh://127.0.0.1:/path/to/dir" - config = {"url": url} - remote = RemoteSSH(None, config) - self.assertEqual(remote.user, getpass.getuser()) - - user = "test1" - url = "ssh://{}@127.0.0.1:/path/to/dir".format(user) - config = {"url": url} - remote = RemoteSSH(None, config) - self.assertEqual(remote.user, user) - - user = "test2" - config["user"] = user - remote = RemoteSSH(None, config) - self.assertEqual(remote.user, user) - def test_url(self): user = "test" host = "123.45.67.89" @@ -56,19 +44,127 @@ def test_no_path(self): remote = RemoteSSH(None, config) self.assertEqual(remote.prefix, "/") - def test_port(self): - url = "ssh://127.0.0.1/path/to/dir" - config = {"url": url} - remote = RemoteSSH(None, config) - self.assertEqual(remote.port, remote.DEFAULT_PORT) - - port = 1234 - url = "ssh://127.0.0.1:{}/path/to/dir".format(port) - config = {"url": url} - remote = RemoteSSH(None, config) - self.assertEqual(remote.port, port) - port = 4321 - config["port"] = port - remote = RemoteSSH(None, config) - self.assertEqual(remote.port, port) +mock_ssh_config = """ +Host example.com + User ubuntu + HostName 1.2.3.4 + Port 1234 + IdentityFile ~/.ssh/not_default.key +""" + +if sys.version_info.major == 3: + builtin_module_name = "builtins" +else: + builtin_module_name = "__builtin__" + + +@pytest.mark.parametrize( + "config,expected_host", + [ + ({"url": "ssh://example.com"}, "1.2.3.4"), + ({"url": "ssh://not_in_ssh_config.com"}, "not_in_ssh_config.com"), + ], +) +@patch("os.path.exists", return_value=True) +@patch( + "{}.open".format(builtin_module_name), + new_callable=mock_open, + read_data=mock_ssh_config, +) +def test_ssh_host_override_from_config( + mock_file, mock_exists, config, expected_host +): + remote = RemoteSSH(None, config) + + mock_exists.assert_called_with(RemoteSSH.ssh_config_filename()) + mock_file.assert_called_with(RemoteSSH.ssh_config_filename()) + assert remote.host == expected_host + + +@pytest.mark.parametrize( + "config,expected_user", + [ + ({"url": "ssh://test1@example.com"}, "test1"), + ({"url": "ssh://example.com", "user": "test2"}, "test2"), + ({"url": "ssh://example.com"}, "ubuntu"), + ({"url": "ssh://test1@not_in_ssh_config.com"}, "test1"), + ( + {"url": "ssh://test1@not_in_ssh_config.com", "user": "test2"}, + "test2", + ), + ({"url": "ssh://not_in_ssh_config.com"}, getpass.getuser()), + ], +) +@patch("os.path.exists", return_value=True) +@patch( + "{}.open".format(builtin_module_name), + new_callable=mock_open, + read_data=mock_ssh_config, +) +def test_ssh_user(mock_file, mock_exists, config, expected_user): + remote = RemoteSSH(None, config) + + mock_exists.assert_called_with(RemoteSSH.ssh_config_filename()) + mock_file.assert_called_with(RemoteSSH.ssh_config_filename()) + assert remote.user == expected_user + + +@pytest.mark.parametrize( + "config,expected_port", + [ + ({"url": "ssh://example.com:2222"}, 2222), + ({"url": "ssh://example.com"}, 1234), + ({"url": "ssh://example.com", "port": 4321}, 4321), + ({"url": "ssh://not_in_ssh_config.com"}, RemoteSSH.DEFAULT_PORT), + ({"url": "ssh://not_in_ssh_config.com:2222"}, 2222), + ({"url": "ssh://not_in_ssh_config.com:2222", "port": 4321}, 4321), + ], +) +@patch("os.path.exists", return_value=True) +@patch( + "{}.open".format(builtin_module_name), + new_callable=mock_open, + read_data=mock_ssh_config, +) +def test_ssh_port(mock_file, mock_exists, config, expected_port): + remote = RemoteSSH(None, config) + + mock_exists.assert_called_with(RemoteSSH.ssh_config_filename()) + mock_file.assert_called_with(RemoteSSH.ssh_config_filename()) + assert remote.port == expected_port + + +@pytest.mark.parametrize( + "config,expected_keyfile", + [ + ( + {"url": "ssh://example.com", "keyfile": "dvc_config.key"}, + "dvc_config.key", + ), + ( + {"url": "ssh://example.com"}, + os.path.expanduser("~/.ssh/not_default.key"), + ), + ( + { + "url": "ssh://not_in_ssh_config.com", + "keyfile": "dvc_config.key", + }, + "dvc_config.key", + ), + ({"url": "ssh://not_in_ssh_config.com"}, None), + ], +) +@patch("os.path.exists", return_value=True) +@patch( + "{}.open".format(builtin_module_name), + new_callable=mock_open, + read_data=mock_ssh_config, +) +def test_ssh_keyfile(mock_file, mock_exists, config, expected_keyfile): + remote = RemoteSSH(None, config) + + mock_exists.assert_called_with(RemoteSSH.ssh_config_filename()) + mock_file.assert_called_with(RemoteSSH.ssh_config_filename()) + assert remote.keyfile == expected_keyfile