Permalink
Browse files

Merge pull request #40 from DouglasTurk/eintr_fixes

Handle/fix handling of EINTR errors in a few places.
  • Loading branch information...
2 parents 4cb9fed + 351bdb7 commit 6560da828db898a3f5f6b2bb9ce0b8aa0424f987 @bitprophet committed Sep 10, 2012
Showing with 52 additions and 8 deletions.
  1. +2 −1 ssh/agent.py
  2. +2 −1 ssh/client.py
  3. +11 −5 ssh/packet.py
  4. +2 −1 ssh/transport.py
  5. +9 −0 ssh/util.py
  6. +26 −0 tests/test_util.py
View
@@ -35,6 +35,7 @@
from ssh.pkey import PKey
from ssh.channel import Channel
from ssh.common import io_sleep
+from ssh.util import retry_on_signal
SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \
SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15)
@@ -202,7 +203,7 @@ def connect(self):
if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'):
conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
- conn.connect(os.environ['SSH_AUTH_SOCK'])
+ retry_on_signal(lambda: conn.connect(os.environ['SSH_AUTH_SOCK']))
except:
# probably a dangling env var: the ssh agent is gone
return
View
@@ -34,6 +34,7 @@
from ssh.rsakey import RSAKey
from ssh.ssh_exception import SSHException, BadHostKeyException
from ssh.transport import Transport
+from ssh.util import retry_on_signal
SSH_PORT = 22
@@ -293,7 +294,7 @@ def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=No
sock.settimeout(timeout)
except:
pass
- sock.connect(addr)
+ retry_on_signal(lambda: sock.connect(addr))
t = self._transport = Transport(sock)
t.use_compression(compress=compress)
if self._log_channel is not None:
View
@@ -241,23 +241,23 @@ def read_all(self, n, check_rekey=False):
def write_all(self, out):
self.__keepalive_last = time.time()
while len(out) > 0:
- got_timeout = False
+ retry_write = False
try:
n = self.__socket.send(out)
except socket.timeout:
- got_timeout = True
+ retry_write = True
except socket.error, e:
if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN):
- got_timeout = True
+ retry_write = True
elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR):
# syscall interrupted; try again
- pass
+ retry_write = True
else:
n = -1
except Exception:
# could be: (32, 'Broken pipe')
n = -1
- if got_timeout:
+ if retry_write:
n = 0
if self.__closed:
n = -1
@@ -469,6 +469,12 @@ def _read_timeout(self, timeout):
break
except socket.timeout:
pass
+ except EnvironmentError, e:
+ if ((type(e.args) is tuple) and (len(e.args) > 0) and
+ (e.args[0] == errno.EINTR)):
+ pass
+ else:
+ raise
if self.__closed:
raise EOFError()
now = time.time()
View
@@ -45,6 +45,7 @@
from ssh.server import ServerInterface
from ssh.sftp_client import SFTPClient
from ssh.ssh_exception import SSHException, BadAuthenticationType, ChannelException
+from ssh.util import retry_on_signal
from Crypto import Random
from Crypto.Cipher import Blowfish, AES, DES3, ARC4
@@ -289,7 +290,7 @@ def __init__(self, sock):
addr = sockaddr
sock = socket.socket(af, socket.SOCK_STREAM)
try:
- sock.connect((hostname, port))
+ retry_on_signal(lambda: sock.connect((hostname, port)))
except socket.error, e:
reason = str(e)
else:
View
@@ -24,6 +24,7 @@
import array
from binascii import hexlify, unhexlify
+import errno
import sys
import struct
import traceback
@@ -270,6 +271,14 @@ def get_logger(name):
l.addFilter(_pfilter)
return l
+def retry_on_signal(function):
+ """Retries function until it doesn't raise an EINTR error"""
+ while True:
+ try:
+ return function()
+ except EnvironmentError, e:
+ if e.errno != errno.EINTR:
+ raise
class Counter (object):
"""Stateful counter for CTR mode crypto"""
View
@@ -22,6 +22,7 @@
from binascii import hexlify
import cStringIO
+import errno
import os
import unittest
from Crypto.Hash import SHA
@@ -177,3 +178,28 @@ def test_7_host_config_expose_issue_33(self):
ssh.util.lookup_ssh_host_config(host, config),
{'hostname': host, 'port': '22'}
)
+
+ def test_8_eintr_retry(self):
+ self.assertEquals('foo', ssh.util.retry_on_signal(lambda: 'foo'))
+
+ # Variables that are set by raises_intr
+ intr_errors_remaining = [3]
+ call_count = [0]
+ def raises_intr():
+ call_count[0] += 1
+ if intr_errors_remaining[0] > 0:
+ intr_errors_remaining[0] -= 1
+ raise IOError(errno.EINTR, 'file', 'interrupted system call')
+ self.assertTrue(ssh.util.retry_on_signal(raises_intr) is None)
+ self.assertEquals(0, intr_errors_remaining[0])
+ self.assertEquals(4, call_count[0])
+
+ def raises_ioerror_not_eintr():
+ raise IOError(errno.ENOENT, 'file', 'file not found')
+ self.assertRaises(IOError,
+ lambda: ssh.util.retry_on_signal(raises_ioerror_not_eintr))
+
+ def raises_other_exception():
+ raise AssertionError('foo')
+ self.assertRaises(AssertionError,
+ lambda: ssh.util.retry_on_signal(raises_other_exception))

0 comments on commit 6560da8

Please sign in to comment.