Skip to content

Commit

Permalink
Implement key_filename stuff re #3
Browse files Browse the repository at this point in the history
Includes a bunch of tightly related changes to network.py,
and tests.
  • Loading branch information
bitprophet committed Feb 3, 2012
1 parent 830fb62 commit 5e704fd
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 21 deletions.
17 changes: 7 additions & 10 deletions fabric/contrib/project.py
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime
from tempfile import mkdtemp

from fabric.network import needs_host
from fabric.network import needs_host, key_filenames
from fabric.operations import local, run, put
from fabric.state import env, output

Expand Down Expand Up @@ -85,14 +85,12 @@ def rsync_project(remote_dir, local_dir=None, exclude=(), delete=False,
exclusions = tuple([str(s).replace('"', '\\\\"') for s in exclude])
# Honor SSH key(s)
key_string = ""
if env.key_filename:
keys = env.key_filename
# For ease of use, coerce stringish key filename into list
if not isinstance(env.key_filename, (list, tuple)):
keys = [keys]
keys = key_filenames()
if keys:
key_string = "-i " + " -i ".join(keys)
# Honor nonstandard port
port_string = ("-p %s" % env.port) if (env.port != '22') else ""
# Port
user, host, port = normalize(env.host_string)
port_string = "-p %s" % port
# RSH
rsh_string = ""
if key_string or port_string:
Expand All @@ -109,8 +107,7 @@ def rsync_project(remote_dir, local_dir=None, exclude=(), delete=False,
if local_dir is None:
local_dir = '../' + getcwd().split(sep)[-1]
# Create and run final command string
cmd = "rsync %s %s %s@%s:%s" % (options, local_dir, env.user,
env.host, remote_dir)
cmd = "rsync %s %s %s@%s:%s" % (options, local_dir, user, host, remote_dir)
if output.running:
print("[%s] rsync_project: %s" % (env.host_string, cmd))
return local(cmd)
Expand Down
49 changes: 40 additions & 9 deletions fabric/network.py
Expand Up @@ -6,6 +6,7 @@

from functools import wraps
import getpass
import os
import re
import threading
import time
Expand Down Expand Up @@ -97,24 +98,54 @@ def __contains__(self, key):
return dict.__contains__(self, normalize_to_string(key))


def ssh_config():
def ssh_config(host_string=None):
"""
Load (memoize) and parse the configured SSH config file.
Return ssh configuration dict for current env.host_string host value.
Assumes that if it's been called, the SSH config option has been set to
True, and aborts if it can't load the requested file.
Memoizes the loaded SSH config file, but not the specific per-host results.
May give an explicit host string as ``host_string``.
"""
from fabric.state import env
if '_ssh_config' not in env:
try:
conf = ssh.SSHConfig()
path = env.ssh_config_path
path = os.path.expanduser(env.ssh_config_path)
with open(path) as fd:
conf.parse(fd)
env._ssh_config = conf
except IOError, e:
abort("Unable to load SSH config file '%s'" % path)
return env._ssh_config
host = parse_host_string(host_string or env.host_string)['host']
return env._ssh_config.lookup(host)


def key_filenames():
"""
Returns list of SSH key filenames for the current env.host_string.
Takes into account ssh_config and env.key_filename, including normalization
to a list. Also performs ``os.path.expanduser`` expansion on any key
filenames.
"""
from fabric.state import env
keys = env.key_filename
# For ease of use, coerce stringish key filename into list
if not isinstance(env.key_filename, (list, tuple)):
keys = [keys]
# Strip out any empty strings (such as the default value...meh)
keys = filter(bool, keys)
# Honor SSH config
if env.use_ssh_config:
# TODO: fix ssh so it correctly treats IdentityFile as a list
conf = ssh_config()
if 'identityfile' in conf:
keys.append(conf['identityfile'])
return map(os.path.expanduser, keys)


def parse_host_string(host_string):
return host_regex.match(host_string).groupdict()


def normalize(host_string, omit_port=False):
Expand All @@ -133,14 +164,14 @@ def normalize(host_string, omit_port=False):
return ('', '') if omit_port else ('', '', '')
# Parse host string (need this early on to look up host-specific ssh_config
# values)
r = host_regex.match(host_string).groupdict()
r = parse_host_string(host_string)
host = r['host']
# Env values
user = env.user
port = env.port
# SSH config data
if env.use_ssh_config:
conf = ssh_config().lookup(host)
conf = ssh_config(host_string)
# Only use ssh_config values if the env value appears unmodified from
# the true defaults. If the user has tweaked them, that new value
# takes precedence.
Expand Down Expand Up @@ -250,7 +281,7 @@ def connect(user, host, port):
port=int(port),
username=user,
password=password,
key_filename=env.key_filename,
key_filename=key_filenames(),
timeout=env.timeout,
allow_agent=not env.no_agent,
look_for_keys=not env.no_keys
Expand Down
2 changes: 1 addition & 1 deletion fabric/operations.py
Expand Up @@ -16,7 +16,7 @@

from fabric.context_managers import settings, char_buffered, hide
from fabric.io import output_loop, input_loop
from fabric.network import needs_host, ssh
from fabric.network import needs_host, ssh, ssh_config
from fabric.sftp import SFTP
from fabric.state import env, connections, output, win32, default_channel
from fabric.thread_handling import ThreadHandler
Expand Down
1 change: 1 addition & 0 deletions fabric/state.py
Expand Up @@ -285,6 +285,7 @@ def _rc_path():
'roles': [],
'roledefs': {},
'skip_bad_hosts': False,
'ssh_config_path': '~/.ssh/config',
# -S so sudo accepts passwd via stdin, -p with our known-value prompt for
# later detection (thus %s -- gets filled with env.sudo_prompt at runtime)
'sudo_prefix': "sudo -S -p '%s' ",
Expand Down
2 changes: 2 additions & 0 deletions tests/support/ssh_config
@@ -1,10 +1,12 @@
Host *
User satan
Port 666
IdentityFile foobar.pub

Host myhost
User neighbor
Port 664
IdentityFile neighbor.pub

Host myalias
HostName otherhost
55 changes: 54 additions & 1 deletion tests/test_network.py
Expand Up @@ -12,7 +12,7 @@

from fabric.context_managers import settings, hide, show
from fabric.network import (HostConnectionCache, join_host_strings, normalize,
denormalize)
denormalize, key_filenames)
from fabric.io import output_loop
import fabric.network # So I can call patch_object correctly. Sigh.
from fabric.state import env, output, _get_system_username
Expand Down Expand Up @@ -570,3 +570,56 @@ def test_real_connection(self):
host_string='testserver',
):
ok_(run("ls /simple").succeeded)


class TestKeyFilenames(FabricTest):
def test_empty_everything(self):
"""
No env.key_filename and no ssh_config = empty list
"""
with settings(use_ssh_config=False):
with settings(key_filename=""):
eq_(key_filenames(), [])
with settings(key_filename=[]):
eq_(key_filenames(), [])

def test_just_env(self):
"""
Valid env.key_filename and no ssh_config = just env
"""
with settings(use_ssh_config=False):
with settings(key_filename="mykey"):
eq_(key_filenames(), ["mykey"])
with settings(key_filename=["foo", "bar"]):
eq_(key_filenames(), ["foo", "bar"])

def test_just_ssh_config(self):
"""
No env.key_filename + valid ssh_config = ssh value
"""
with settings(use_ssh_config=True, ssh_config_path=support("ssh_config")):
for val in ["", []]:
with settings(key_filename=val):
eq_(key_filenames(), ["foobar.pub"])

def test_both(self):
"""
Both env.key_filename + valid ssh_config = both show up w/ env var first
"""
with settings(use_ssh_config=True, ssh_config_path=support("ssh_config")):
with settings(key_filename="bizbaz.pub"):
eq_(key_filenames(), ["bizbaz.pub", "foobar.pub"])
with settings(key_filename=["bizbaz.pub", "whatever.pub"]):
expected = ["bizbaz.pub", "whatever.pub", "foobar.pub"]
eq_(key_filenames(), expected)

def test_specific_host(self):
"""
SSH lookup aspect should correctly select per-host value
"""
with settings(
use_ssh_config=True,
ssh_config_path=support("ssh_config"),
host_string="myhost"
):
eq_(key_filenames(), ["neighbor.pub"])

0 comments on commit 5e704fd

Please sign in to comment.