Skip to content

Commit

Permalink
Merge pull request #53 from dguerri/master
Browse files Browse the repository at this point in the history
Add session persistence
  • Loading branch information
James Bardin committed Apr 24, 2015
2 parents 2f83933 + f361317 commit a57e683
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 11 deletions.
26 changes: 26 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,32 @@ Example
scp.put('test.txt', 'test2.txt')
scp.get('test2.txt')
scp.close()
.. code-block::
$ md5sum test.txt test2.txt
fc264c65fb17b7db5237cf7ce1780769 test.txt
fc264c65fb17b7db5237cf7ce1780769 test2.txt
Using 'with' keyword
------------------

.. code-block:: python
from paramiko import SSHClient
from scp import SCPClient
ssh = SSHClient()
ssh.load_system_host_keys()
ssh.connect('example.com')
with SCPClient(ssh.get_transport()) as scp:
scp.put('test.txt', 'test2.txt')
scp.get('test2.txt')
.. code-block::
$ md5sum test.txt test2.txt
Expand Down
26 changes: 20 additions & 6 deletions scp.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def __init__(self, transport, buff_size=16384, socket_timeout=5.0,
self.sanitize = sanitize
self._dirtimes = {}

def __enter__(self):
self.channel = self._open()
return self

def __exit__(self, type, value, traceback):
self.close()

def put(self, files, remote_path=b'.',
recursive=False, preserve_times=False):
"""
Expand All @@ -130,7 +137,7 @@ def put(self, files, remote_path=b'.',
@type preserve_times: bool
"""
self.preserve_times = preserve_times
self.channel = self.transport.open_session()
self.channel = self._open()
self._pushed = 0
self.channel.settimeout(self.socket_timeout)
scp_command = (b'scp -t ', b'scp -r -t ')[recursive]
Expand All @@ -146,9 +153,6 @@ def put(self, files, remote_path=b'.',
else:
self._send_files(files)

if self.channel:
self.channel.close()

def get(self, remote_path, local_path='',
recursive=False, preserve_times=False):
"""
Expand Down Expand Up @@ -181,7 +185,7 @@ def get(self, remote_path, local_path='',
asunicode(self._recv_dir))
rcsv = (b'', b' -r')[recursive]
prsv = (b'', b' -p')[preserve_times]
self.channel = self.transport.open_session()
self.channel = self._open()
self._pushed = 0
self.channel.settimeout(self.socket_timeout)
self.channel.exec_command(b"scp" +
Expand All @@ -191,8 +195,18 @@ def get(self, remote_path, local_path='',
b' '.join(remote_path))
self._recv_all()

if self.channel:
def _open(self):
"""open a scp channel"""
if self.channel is None:
self.channel = self.transport.open_session()

return self.channel

def close(self):
"""close scp channel"""
if self.channel is not None:
self.channel.close()
self.channel = None

def _read_stats(self, name):
"""return just the file stats needed for scp"""
Expand Down
11 changes: 6 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,10 @@ def download_test(self, filename, recursive, destination=None,
os.mkdir(temp_in)
os.chdir(temp_in)
try:
scp = SCPClient(self.ssh.get_transport())
scp.get(filename, destination if destination is not None else u'.',
preserve_times=True, recursive=recursive)
with SCPClient(self.ssh.get_transport()) as scp:
scp.get(filename,
destination if destination is not None else u'.',
preserve_times=True, recursive=recursive)
actual = []
def listdir(path, fpath):
for name in os.listdir(fpath):
Expand Down Expand Up @@ -204,8 +205,8 @@ def upload_test(self, filenames, recursive, expected=[]):
previous = os.getcwd()
try:
os.chdir(self._temp)
scp = SCPClient(self.ssh.get_transport())
scp.put(filenames, destination, recursive)
with SCPClient(self.ssh.get_transport()) as scp:
scp.put(filenames, destination, recursive)

chan = self.ssh.get_transport().open_session()
chan.exec_command(
Expand Down

0 comments on commit a57e683

Please sign in to comment.