Skip to content
This repository has been archived by the owner on Jun 1, 2018. It is now read-only.

Commit

Permalink
fix #113, make Reader.peek() work on Python 3
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Feb 1, 2016
1 parent 7c83a70 commit bda49dd
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
30 changes: 25 additions & 5 deletions netlib/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@

version_check.check_pyopenssl_version()

if six.PY2:
socket_fileobject = socket._fileobject
else:
socket_fileobject = socket.SocketIO

EINTR = 4

Expand Down Expand Up @@ -270,7 +274,7 @@ def peek(self, length):
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
"""
if isinstance(self.o, socket._fileobject):
if isinstance(self.o, socket_fileobject):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
Expand Down Expand Up @@ -423,8 +427,17 @@ class _Connection(object):
def __init__(self, connection):
if connection:
self.connection = connection
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
# Ideally, we would use the Buffered IO in Python 3 by default.
# Unfortunately, the implementation of .peek() is broken for n>1 bytes,
# as it may just return what's left in the buffer and not all the bytes we want.
# As a workaround, we just use unbuffered sockets directly.
# https://mail.python.org/pipermail/python-dev/2009-June/089986.html
if six.PY2:
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
else:
self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
else:
self.connection = None
self.rfile = None
Expand Down Expand Up @@ -663,8 +676,15 @@ def connect(self):
connection.connect(self.address())
if not self.source_address:
self.source_address = Address(connection.getsockname())
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))

# See _Connection.__init__ why we do this dance.
if six.PY2:
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
else:
self.rfile = Reader(socket.SocketIO(connection, "rb"))
self.wfile = Writer(socket.SocketIO(connection, "wb"))

except (socket.error, IOError) as err:
raise TcpException(
'Error connecting to "%s": %s' %
Expand Down
2 changes: 1 addition & 1 deletion test/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def test_peek(self):
c.wfile.write(testval)
c.wfile.flush()

assert c.rfile.peek(4) == "peek"[:4]
assert c.rfile.peek(4) == b"peek"[:4]
assert c.rfile.peek(6) == testval


Expand Down

0 comments on commit bda49dd

Please sign in to comment.