Skip to content

Commit

Permalink
Turn shells into context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
mwilliamson committed Jan 21, 2013
1 parent c976fd9 commit e5f48ca
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 12 deletions.
6 changes: 6 additions & 0 deletions spur/local.py
Expand Up @@ -9,6 +9,12 @@


class LocalShell(object):
def __enter__(self):
return self

def __exit__(self, *args):
pass

def upload_dir(self, source, dest, ignore=None):
shutil.copytree(source, dest, ignore=shutil.ignore_patterns(*ignore))

Expand Down
11 changes: 11 additions & 0 deletions spur/ssh.py
Expand Up @@ -32,6 +32,15 @@ def __init__(self, hostname, username, password=None, port=22, private_key_file=
self._private_key_file = private_key_file
self._client = None
self._connect_timeout = connect_timeout if not None else _ONE_MINUTE
self._closed = False

def __enter__(self):
return self

def __exit__(self, *args):
self._closed = True
if self._client is not None:
self._client.close()

def run(self, *args, **kwargs):
return self.spawn(*args, **kwargs).wait_for_result()
Expand Down Expand Up @@ -136,6 +145,8 @@ def _get_ssh_transport(self):

def _connect_ssh(self):
if self._client is None:
if self._closed:
raise RuntimeError("Shell is closed")
client = paramiko.SSHClient()
client.load_system_host_keys()
client.set_missing_host_key_policy(paramiko.WarningPolicy())
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@

8 changes: 8 additions & 0 deletions tests/ssh_tests.py
Expand Up @@ -2,6 +2,7 @@

import spur
import spur.ssh
from .testing import create_ssh_shell


@istest
Expand All @@ -17,3 +18,10 @@ def try_connection():

assert_raises(spur.ssh.ConnectionError, try_connection)


@istest
def trying_to_use_ssh_shell_after_exit_results_in_error():
with create_ssh_shell() as shell:
pass

assert_raises(Exception, lambda: shell.run(["true"]))
14 changes: 14 additions & 0 deletions tests/testing.py
@@ -0,0 +1,14 @@
import os

import spur


def create_ssh_shell():
port_var = os.environ.get("TEST_SSH_PORT")
port = int(port_var) if port_var is not None else None
return spur.SshShell(
hostname=os.environ.get("TEST_SSH_HOSTNAME", "127.0.0.1"),
username=os.environ["TEST_SSH_USERNAME"],
password=os.environ["TEST_SSH_PASSWORD"],
port=port
)
17 changes: 5 additions & 12 deletions tests/tests.py
Expand Up @@ -8,29 +8,22 @@
from nose.tools import istest, assert_equal, assert_raises, assert_true

import spur
from .testing import create_ssh_shell


def test(func):
@functools.wraps(func)
def run_test():
for shell in _create_shells():
yield func, shell
with shell:
yield func, shell

def _create_shells():
return [
spur.LocalShell(),
_create_ssh_shell()
create_ssh_shell()
]

def _create_ssh_shell():
port_var = os.environ.get("TEST_SSH_PORT")
port = int(port_var) if port_var is not None else None
return spur.SshShell(
hostname=os.environ.get("TEST_SSH_HOSTNAME", "127.0.0.1"),
username=os.environ["TEST_SSH_USERNAME"],
password=os.environ["TEST_SSH_PASSWORD"],
port=port
)

return istest(run_test)

@test
Expand Down

0 comments on commit e5f48ca

Please sign in to comment.