Skip to content

Commit

Permalink
Implicitly Convert Unicode to String for send and sendall (#256)
Browse files Browse the repository at this point in the history
Fixes #254 by implicitly encoding unicode to str. This mirrors CPython in
accepting unicode, even though the best practice is to pass only strings
(or bytes) to a socket explicitly.

We add test cases that are based on the existing cases for socket.
  • Loading branch information
TheMatt2 committed Aug 31, 2023
1 parent fb29aa5 commit c020c5e
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 4 deletions.
1 change: 1 addition & 0 deletions ACKNOWLEDGMENTS
Expand Up @@ -194,6 +194,7 @@ Jython in ways large and small, in no particular order:
Dominik Broj
shoaniki (GitHub name)
Zhihong (Ted) Yu
Matthew Schweiss (GitHub: TheMatt2)

Local Variables:
mode: indented-text
Expand Down
8 changes: 7 additions & 1 deletion Lib/_socket.py
Expand Up @@ -940,7 +940,7 @@ def _peer_closed(x):
self.incoming.put(_PEER_CLOSED)
self._notify_selectors(hangup=True)

log.debug("Add _peer_closed to channel close", extra={"sock": self})
log.debug("Add _peer_closed to channel close", extra={"sock": self})
self.channel.closeFuture().addListener(_peer_closed)

def connect(self, addr):
Expand Down Expand Up @@ -1191,6 +1191,9 @@ def _verify_channel(self):

@raises_java_exception
def send(self, data, flags=0):
if isinstance(data, unicode):
data = data.encode()

self._verify_channel()
if log.isEnabledFor(logging.DEBUG):
log.debug("Sending data <<<{!r:.20}>>>".format(data), extra={"sock": self})
Expand All @@ -1216,6 +1219,9 @@ def send(self, data, flags=0):
return len(buf)

def sendall(self, data, flags=0):
if isinstance(data, unicode):
data = data.encode()

with memoryview(data) as buf:
# Limit the amount per send to L to control data movement
k, n, L = 0, len(buf), 8192
Expand Down
193 changes: 191 additions & 2 deletions Lib/test/test_socket_jy.py
Expand Up @@ -9,8 +9,7 @@
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from SocketServer import ThreadingMixIn
from test import test_support
from test.test_socket import SocketConnectedTest

from test.test_socket import SocketConnectedTest, ThreadedUDPSocketTest

def data_file(*name):
    """Return the path built from *name* components, rooted at this file's directory."""
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
Expand All @@ -19,6 +18,8 @@ def data_file(*name):
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")

MSG = 'Michael Gilfix was here\n'

def start_server():
server_address = ('127.0.0.1', 0)

Expand Down Expand Up @@ -203,12 +204,200 @@ def _testSendAllBuffer(self):
self.serv_conn.sendall(big)


class BasicTCPUnicodeTest(SocketConnectedTest):
    """Check that unicode passed to send()/sendall() over TCP arrives as str.

    Each ``testX`` method receives on ``self.cli_conn`` and asserts that the
    payload equals the expected text and is exactly of type ``str``.  The
    matching ``_testX`` method sends unicode on ``self.serv_conn``.
    NOTE(review): the ``testX``/``_testX`` pairing is presumably driven by the
    SocketConnectedTest harness from test.test_socket -- confirm there.
    """

    def __init__(self, methodName='runTest'):
        SocketConnectedTest.__init__(self, methodName=methodName)

    def testRecv(self):
        # Testing large receive over TCP
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testRecv(self):
        # MSG.decode() turns the str constant into unicode before sending.
        self.serv_conn.send(MSG.decode())

    def testRecvTimeoutMode(self):
        # Do this test in timeout mode, because the code path is different
        self.cli_conn.settimeout(10)
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testRecvTimeoutMode(self):
        self.serv_conn.settimeout(10)
        self.serv_conn.send(MSG.decode())

    def testOverFlowRecv(self):
        # Testing receive in chunks over TCP
        seg1 = self.cli_conn.recv(len(MSG) - 3)
        seg2 = self.cli_conn.recv(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testOverFlowRecv(self):
        self.serv_conn.send(MSG.decode())

    def testRecvFrom(self):
        # Testing large recvfrom() over TCP
        msg, addr = self.cli_conn.recvfrom(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testRecvFrom(self):
        self.serv_conn.send(MSG.decode())

    def testOverFlowRecvFrom(self):
        # Testing recvfrom() in chunks over TCP
        seg1, addr = self.cli_conn.recvfrom(len(MSG) - 3)
        seg2, addr = self.cli_conn.recvfrom(1024)
        msg = seg1 + seg2
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testOverFlowRecvFrom(self):
        self.serv_conn.send(MSG.decode())

    def testSendAll(self):
        # Testing sendall() with a 2048 character unicode string over TCP
        chunks = []
        while True:
            read = self.cli_conn.recv(1024)
            if not read:
                # An empty read ends the receive loop.
                break
            chunks.append(read)
        msg = ''.join(chunks)
        self.assertEqual(msg, 'f' * 2048)
        self.assertEqual(type(msg), str)

    def _testSendAll(self):
        big_chunk = u'f' * 2048
        self.serv_conn.sendall(big_chunk)

    # NOTE(review): no matching testFromFd is defined in this class, so this
    # sender looks unused unless a base class provides the receiving half --
    # confirm against test.test_socket.
    def _testFromFd(self):
        self.serv_conn.send(MSG.decode())

    def testShutdown(self):
        # Testing shutdown()
        msg = self.cli_conn.recv(1024)
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testShutdown(self):
        self.serv_conn.send(MSG.decode())
        self.serv_conn.shutdown(2)

    def testSendAfterRemoteClose(self):
        self.cli_conn.close()

    def _testSendAfterRemoteClose(self):
        # Retry a few times: the send is expected to start failing with
        # ECONNRESET once the peer's close has been noticed.
        for _attempt in range(5):
            try:
                self.serv_conn.send(u"spam")
            except socket.error as se:
                self.assertEqual(se.args[0], errno.ECONNRESET)
                return
            except Exception as exc:
                self.fail("Sending on remotely closed socket raised wrong exception: %s" % exc)
            time.sleep(0.5)
        self.fail("Sending on remotely closed socket should have raised exception")

    def testDup(self):
        msg = self.cli_conn.recv(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

        dup_conn = self.cli_conn.dup()
        try:
            expected = 'and ' + MSG
            msg = dup_conn.recv(len(expected))
            self.assertEqual(msg, expected)
            # The duplicated socket must hand back str as well.
            self.assertEqual(type(msg), str)
        finally:
            dup_conn.close()  # need to ensure all sockets are closed

    def _testDup(self):
        self.serv_conn.send(MSG.decode())
        self.serv_conn.send(u'and ' + MSG)


class BasicUDPUnicodeTest(ThreadedUDPSocketTest):
    """Check that unicode passed to send()/sendto() over UDP arrives as str.

    Each ``testX`` method receives on ``self.serv`` and asserts that the
    payload equals MSG and is exactly of type ``str``.  The matching
    ``_testX`` method sends unicode on ``self.cli``, also passing the host
    name as unicode.
    NOTE(review): the ``testX``/``_testX`` pairing is presumably driven by the
    ThreadedUDPSocketTest harness from test.test_socket -- confirm there.
    """

    def __init__(self, methodName='runTest'):
        ThreadedUDPSocketTest.__init__(self, methodName=methodName)

    def testSendtoAndRecv(self):
        # Testing sendto() and recv() over UDP
        msg = self.serv.recv(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testSendtoAndRecv(self):
        self.cli.sendto(MSG.decode(), 0, (self.HOST.decode(), self.PORT))

    def testSendtoAndRecvTimeoutMode(self):
        # Need to test again in timeout mode, which follows
        # a different code path
        self.serv.settimeout(1)
        msg = self.serv.recv(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testSendtoAndRecvTimeoutMode(self):
        self.cli.settimeout(10)
        self.cli.sendto(MSG.decode(), 0, (self.HOST.decode(), self.PORT))

    def testSendAndRecv(self):
        # Testing send() and recv() over connect'ed UDP
        msg = self.serv.recv(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testSendAndRecv(self):
        self.cli.connect((self.HOST.decode(), self.PORT))
        self.cli.send(MSG.decode(), 0)

    def testSendAndRecvTimeoutMode(self):
        # Need to test again in timeout mode, which follows
        # a different code path
        self.serv.settimeout(5)
        # Testing send() and recv() over connect'ed UDP
        msg = self.serv.recv(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testSendAndRecvTimeoutMode(self):
        self.cli.connect((self.HOST.decode(), self.PORT))
        self.cli.settimeout(5)
        time.sleep(1)
        self.cli.send(MSG.decode(), 0)

    def testRecvFrom(self):
        # Testing recvfrom() over UDP
        msg, addr = self.serv.recvfrom(len(MSG))
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testRecvFrom(self):
        self.cli.sendto(MSG.decode(), 0, (self.HOST.decode(), self.PORT))

    def testRecvFromTimeoutMode(self):
        # Need to test again in timeout mode, which follows
        # a different code path
        self.serv.settimeout(1)
        msg, addr = self.serv.recvfrom(len(MSG))
        # Assert on the raw payload and its type, like the other tests here,
        # rather than decoding it first.
        self.assertEqual(msg, MSG)
        self.assertEqual(type(msg), str)

    def _testRecvFromTimeoutMode(self):
        self.cli.settimeout(1)
        self.cli.sendto(MSG.decode(), 0, (self.HOST.decode(), self.PORT))


def test_main():
    """Hand the test case classes listed below to test_support.run_unittest."""
    cases = (
        SocketConnectTest,
        SSLSocketConnectTest,
        SocketOptionsTest,
        TimedBasicTCPTest,
        BasicTCPUnicodeTest,
        BasicUDPUnicodeTest,
    )
    test_support.run_unittest(*cases)


Expand Down
3 changes: 2 additions & 1 deletion NEWS
Expand Up @@ -19,7 +19,8 @@ New Features


Jython 2.7.4a1 Bugs fixed
- [ GH-269 ] Upgrade Google Guava to 32.0.1 (CVE-2023-2976)
- [ GH-269 ] Upgrade Google Guava to 32.0.1 (CVE-2023-2976)
- [ GH-254 ] Regression in socket.socket.sendall for sending Unicode


==============================================================================
Expand Down

0 comments on commit c020c5e

Please sign in to comment.