Permalink
Browse files

Added support for Python 3.x -- test suite passes on Python 3.1 and

2.6
  • Loading branch information...
Alex Grönholm
Alex Grönholm committed Mar 8, 2011
1 parent f121918 commit 36a145671501ed47bc4002af7cab49b490eb6e0b
@@ -0,0 +1,25 @@
#coding: utf-8
'''
Compatibility functions for unified behavior between Python 2.x and 3.x.
:author: Alex Grönholm
'''
import sys
if sys.version_info[0] < 3:
items = lambda d: d.items()
iteritems = lambda d: d.iteritems()
next = lambda x: x.next()
range = xrange
long = long
basestring = basestring
unicode = unicode
else:
items = lambda d: list(d.items())
iteritems = lambda d: d.items()
next = next
range = range
long = int
basestring = str
unicode = str
@@ -10,6 +10,8 @@
from threading import RLock
from py4j.compat import items
class ThreadSafeFinalizer(object):
"""A `ThreadSafeFinalizer` is a global class used to register weak reference finalizers
@@ -57,7 +59,7 @@ def clear_finalizers(cls, clear_all=False):
if clear_all:
cls.finalizers.clear()
else:
for id, ref in cls.finalizers.items():
for id, ref in items(cls.finalizers):
if ref() is None:
cls.finalizers.pop(id, None)
@@ -104,7 +106,7 @@ def clear_finalizers(cls, clear_all=False):
if clear_all:
cls.finalizers.clear()
else:
for id, ref in cls.finalizers.items():
for id, ref in items(cls.finalizers):
if ref() is None:
cls.finalizers.pop(id, None)
@@ -9,9 +9,11 @@
:author: Barthelemy Dagenais
'''
from collections import MutableMapping, Sequence, MutableSequence, MutableSet, Set
import sys
from py4j.java_gateway import JavaObject, JavaMember, get_method, JavaClass
from py4j.protocol import *
from py4j.compat import iteritems, next
class JavaIterator(JavaObject):
@@ -38,6 +40,8 @@ def next(self):
return self._methods[self._next_name]()
except Py4JError:
raise StopIteration()
__next__ = next
class JavaMap(JavaObject, MutableMapping):
@@ -70,24 +74,31 @@ def __contains__(self, key):
def __str__(self):
return self.__repr__()
def __repr__(self):
# TODO Make it more efficient/pythonic
# TODO Debug why strings are not outputed with apostrophes.
if len(self) == 0:
return '{}'
else:
srep = '{'
for key in self:
srep += repr(key) + ': ' + repr(self[key]) + ', '
# def __repr__(self):
# # TODO Make it more efficient/pythonic
# # TODO Debug why strings are not outputed with apostrophes.
# if len(self) == 0:
# return '{}'
# else:
# srep = '{'
# for key in self:
# srep += repr(key) + ': ' + repr(self[key]) + ', '
#
# return srep[:-2] + '}'
return srep[:-2] + '}'
def __repr__(self):
items = ('{0}: {1}'.format(repr(k), repr(v)) for k, v in iteritems(self))
return '{{{0}}}'.format(', '.join(items))
class JavaSet(JavaObject, MutableSet):
"""Maps a Python Set to a Java Set.
All operations possible on a Python set are implemented."""
__EMPTY_SET = 'set([])' if sys.version_info[0] < 3 else 'set()'
__SET_TEMPLATE = 'set([{0}])' if sys.version_info[0] < 3 else '{{{0}}}'
def __init__(self, target_id, gateway_client):
JavaObject.__init__(self, target_id, gateway_client)
self._add = get_method(self, 'add')
@@ -122,14 +133,9 @@ def __str__(self):
return self.__repr__()
def __repr__(self):
if len(self) == 0:
return 'set([])'
else:
srep = 'set(['
for value in self:
srep += repr(value) + ', '
return srep[:-2] + '])'
if len(self):
return self.__SET_TEMPLATE.format(', '.join((repr(x) for x in self)))
return self.__EMPTY_SET
class JavaArray(JavaObject, Sequence):
@@ -182,7 +188,7 @@ def __getitem__(self, key):
def __repl_item_from_slice(self, range, iterable):
value_iter = iter(iterable)
for i in range:
value = value_iter.next()
value = next(value_iter)
self.__set_item(i, value)
def __set_item(self, key, value):
@@ -225,6 +231,9 @@ class JavaList(JavaObject, MutableSequence):
will create a copy of the list on the JVM. Slicing is thus not equivalent to subList(), because
a modification to a slice such as the addition of a new element will not affect the original
list."""
__EMPTY_SET = '[]' if sys.version_info[0] < 3 else '{}'
__REPR_TEMPLATE = 'set([%s])' if sys.version_info[0] < 3 else '{%s}'
def __init__(self, target_id, gateway_client):
JavaObject.__init__(self, target_id, gateway_client)
@@ -263,7 +272,7 @@ def __set_item_from_slice(self, indices, iterable):
# First replace and delete if from_slice > to_slice
for i in range(*indices):
try:
value = value_iter.next()
value = next(value_iter)
self.__set_item(i, value)
except StopIteration:
self.__del_item(i)
@@ -284,7 +293,7 @@ def __insert_item_from_slice(self, indices, iterable):
def __repl_item_from_slice(self, range, iterable):
value_iter = iter(iterable)
for i in range:
value = value_iter.next()
value = value = next(value_iter)
self.__set_item(i, value)
def __append_item_from_slice(self, range, iterable):
@@ -426,14 +435,8 @@ def __str__(self):
return self.__repr__()
def __repr__(self):
if len(self) == 0:
return '[]'
else:
srep = '['
for elem in self:
srep += repr(elem) + ', '
return srep[:-2] + ']'
items = (repr(x) for x in self)
return '[{0}]'.format(', '.join(items))
class SetConverter(object):
@@ -20,6 +20,7 @@
from py4j.finalizer import ThreadSafeFinalizer
from py4j.protocol import *
from py4j.compat import range
class NullHandler(logging.Handler):
@@ -258,7 +259,7 @@ def start(self):
"""Starts the connection by connecting to the `address` and the `port`"""
self.socket.connect((self.address, self.port))
self.is_connected = True
self.stream = self.socket.makefile('r', 0)
self.stream = self.socket.makefile('rb', 0)
def close(self, throw_exception=False):
"""Closes the connection by closing the socket."""
@@ -703,7 +704,7 @@ def run(self):
while not self.is_shutdown:
socket, _ = self.server_socket.accept()
input = socket.makefile('r', 0)
input = socket.makefile('rb', 0)
connection = CallbackConnection(self.pool, input, socket, self.gateway_client)
with self.lock:
if not self.is_shutdown:
@@ -16,6 +16,9 @@
:author: Barthelemy Dagenais
'''
from py4j.compat import long, basestring
ESCAPE_CHAR = "\\"
# Entry point
@@ -3,6 +3,7 @@
@author: Barthelemy Dagenais
'''
from __future__ import unicode_literals
from multiprocessing.process import Process
import subprocess
import time
@@ -46,12 +47,12 @@ def testArray(self):
self.assertEqual(3, len(array1))
self.assertEqual(4, len(array2))
self.assertEqual(u'333', array1[2])
self.assertEqual('333', array1[2])
self.assertEqual(5, array2[1])
array1[2] = 'aaa'
array2[1] = 6
self.assertEqual(u'aaa', array1[2])
self.assertEqual('aaa', array1[2])
self.assertEqual(6, array2[1])
new_array = array2[1:3]
@@ -11,6 +11,7 @@
from py4j.java_gateway import JavaGateway, PythonProxyPool
from py4j.tests.java_gateway_test import PY4J_JAVA_PATH
from py4j.compat import range
def start_example_server():
@@ -83,7 +84,7 @@ class TestPool(unittest.TestCase):
def testPool(self):
pool = PythonProxyPool()
runners = [Runner(xrange(0, 10000), pool) for _ in xrange(0, 3)]
runners = [Runner(range(0, 10000), pool) for _ in range(0, 3)]
for runner in runners:
runner.start()
@@ -3,6 +3,7 @@
@author: barthelemy
'''
from __future__ import unicode_literals
from multiprocessing.process import Process
from socket import AF_INET, SOCK_STREAM, socket
from threading import Thread
@@ -17,6 +18,7 @@
from py4j.protocol import *
from py4j.java_gateway import JavaGateway, JavaMember, get_field, get_method, \
GatewayClient, set_field, java_import, JavaObject
from py4j.compat import range
SERVER_PORT = 25333
@@ -90,7 +92,7 @@ def tearDown(self):
def testEscape(self):
self.assertEqual("Hello\t\rWorld\n\\", unescape_new_line(escape_new_line("Hello\t\rWorld\n\\")))
self.assertEqual(u"Hello\t\rWorld\n\\", unescape_new_line(escape_new_line(u"Hello\t\rWorld\n\\")))
self.assertEqual("Hello\t\rWorld\n\\", unescape_new_line(escape_new_line("Hello\t\rWorld\n\\")))
def testProtocolSend(self):
testConnection = TestConnection()
@@ -179,7 +181,7 @@ def testException(self):
testSocket.sendall('yo\n'.encode('utf-8'))
testSocket.sendall('yro0\n'.encode('utf-8'))
testSocket.sendall('yo\n'.encode('utf-8'))
testSocket.sendall('x\n')
testSocket.sendall(b'x\n')
testSocket.close()
time.sleep(1)
@@ -235,13 +237,13 @@ def testNoneArg(self):
def testUnicode(self):
sb = self.gateway.jvm.java.lang.StringBuffer()
sb.append(u'\r\n\tHello\r\n\t')
self.assertEqual(u'\r\n\tHello\r\n\t', sb.toString())
sb.append('\r\n\tHello\r\n\t')
self.assertEqual('\r\n\tHello\r\n\t', sb.toString())
def testEscape(self):
sb = self.gateway.jvm.java.lang.StringBuffer()
sb.append('\r\n\tHello\r\n\t')
self.assertEqual(u'\r\n\tHello\r\n\t', sb.toString())
self.assertEqual('\r\n\tHello\r\n\t', sb.toString())
class FieldTest(unittest.TestCase):
@@ -260,7 +262,7 @@ def testAutoField(self):
self.assertEqual(ex.field10, 10)
sb = ex.field20
sb.append('Hello')
self.assertEqual(u'Hello', sb.toString())
self.assertEqual('Hello', sb.toString())
self.assertTrue(ex.field21 == None)
def testNoField(self):
@@ -285,7 +287,7 @@ def testNoAutoField(self):
ex._auto_field = True
sb = ex.field20
sb.append('Hello')
self.assertEqual(u'Hello', sb.toString())
self.assertEqual('Hello', sb.toString())
try:
get_field(ex, 'field20')
@@ -302,7 +304,7 @@ def testSetField(self):
sb = self.gateway.jvm.java.lang.StringBuffer('Hello World!')
set_field(ex, 'field21', sb)
self.assertEquals(get_field(ex, 'field21').toString(), u'Hello World!')
self.assertEquals(get_field(ex, 'field21').toString(), 'Hello World!')
try:
set_field(ex, 'field1', 123)
@@ -458,22 +460,22 @@ def testConstructors(self):
sb = jvm.java.lang.StringBuffer('hello')
sb.append('hello world')
sb.append(1)
self.assertEqual(sb.toString(), u'hellohello world1')
self.assertEqual(sb.toString(), 'hellohello world1')
l1 = jvm.java.util.ArrayList()
l1.append('hello world')
l1.append(1)
self.assertEqual(2, len(l1))
self.assertEqual(u'hello world', l1[0])
l2 = [u'hello world', 1]
self.assertEqual('hello world', l1[0])
l2 = ['hello world', 1]
print(l1)
print(l2)
self.assertEqual(str(l2), str(l1))
def testStaticMethods(self):
System = self.gateway.jvm.java.lang.System
self.assertTrue(System.currentTimeMillis() > 0)
self.assertEqual(u'123', self.gateway.jvm.java.lang.String.valueOf(123))
self.assertEqual('123', self.gateway.jvm.java.lang.String.valueOf(123))
def testStaticFields(self):
Short = self.gateway.jvm.java.lang.Short
@@ -483,7 +485,7 @@ def testStaticFields(self):
def testDefaultImports(self):
self.assertTrue(self.gateway.jvm.System.currentTimeMillis() > 0)
self.assertEqual(u'123', self.gateway.jvm.String.valueOf(123))
self.assertEqual('123', self.gateway.jvm.String.valueOf(123))
def testNone(self):
ex = self.gateway.entry_point.getNewExample()
@@ -581,9 +583,9 @@ def testStress(self):
# runner2 = Runner(xrange(1000,1000000,10000), self.gateway)
# runner3 = Runner(xrange(1000,1000000,10000), self.gateway)
# Small stress test
runner1 = Runner(xrange(1, 10000, 1000), self.gateway)
runner2 = Runner(xrange(1000, 1000000, 100000), self.gateway)
runner3 = Runner(xrange(1000, 1000000, 100000), self.gateway)
runner1 = Runner(range(1, 10000, 1000), self.gateway)
runner2 = Runner(range(1000, 1000000, 100000), self.gateway)
runner3 = Runner(range(1000, 1000000, 100000), self.gateway)
runner1.start()
runner2.start()
runner3.start()
Oops, something went wrong.

0 comments on commit 36a1456

Please sign in to comment.