diff --git a/hazelcast/serialization/portable/classdef.py b/hazelcast/serialization/portable/classdef.py index 9c99a95d5a..a3c66e00c6 100644 --- a/hazelcast/serialization/portable/classdef.py +++ b/hazelcast/serialization/portable/classdef.py @@ -27,22 +27,26 @@ class FieldDefinition(object): - def __init__(self, index, field_name, field_type, factory_id=0, class_id=0): + def __init__(self, index, field_name, field_type, version, factory_id=0, class_id=0): self.index = index self.field_name = field_name self.field_type = field_type + self.version = version self.factory_id = factory_id self.class_id = class_id def __eq__(self, other): return isinstance(other, self.__class__) \ - and (self.index, self.field_name, self.field_type, self.factory_id, self.class_id) == \ - (other.index, other.field_name, other.field_type, other.factory_id, other.class_id) + and (self.index, self.field_name, self.field_type, self.version, self.factory_id, self.class_id) == \ + (other.index, other.field_name, other.field_type, other.version, other.factory_id, other.class_id) def __repr__(self): - return "FieldDefinition[ ix:{}, name:{}, type:{}, fid:{}, cid:{}]".format(self.index, self.field_name, self.field_type, - self.factory_id, - self.class_id) + return "FieldDefinition[ ix:{}, name:{}, type:{}, version:{}, fid:{}, cid:{}]".format(self.index, + self.field_name, + self.field_type, + self.version, + self.factory_id, + self.class_id) class ClassDefinition(object): @@ -94,7 +98,7 @@ def set_version_if_not_set(self, version): def __eq__(self, other): return isinstance(other, self.__class__) and (self.factory_id, self.class_id, self.version, self.field_defs) == \ - (other.factory_id, other.class_id, other.version, self.field_defs) + (other.factory_id, other.class_id, other.version, other.field_defs) def __ne__(self, other): return not self.__eq__(other) @@ -105,8 +109,9 @@ def __repr__(self): def __hash__(self): return id(self)//16 + class ClassDefinitionBuilder(object): - def __init__(self, factory_id, class_id, version=-1): + def __init__(self, factory_id, class_id, version=0): self.factory_id = factory_id self.class_id = class_id self.version = version @@ -118,85 +123,87 @@ def __init__(self, factory_id, class_id, version=-1): def add_portable_field(self, field_name, class_def): if class_def.class_id is None or class_def.class_id == 0: raise ValueError("Portable class id cannot be zero!") - self._add_field_by_type(field_name, FieldType.PORTABLE, class_def.factory_id, class_def.class_id) + self._add_field_by_type(field_name, FieldType.PORTABLE, class_def.version, + class_def.factory_id, class_def.class_id) return self def add_byte_field(self, field_name): - self._add_field_by_type(field_name, FieldType.BYTE) + self._add_field_by_type(field_name, FieldType.BYTE, self.version) return self def add_boolean_field(self, field_name): - self._add_field_by_type(field_name, FieldType.BOOLEAN) + self._add_field_by_type(field_name, FieldType.BOOLEAN, self.version) return self def add_char_field(self, field_name): - self._add_field_by_type(field_name, FieldType.CHAR) + self._add_field_by_type(field_name, FieldType.CHAR, self.version) return self def add_short_field(self, field_name): - self._add_field_by_type(field_name, FieldType.SHORT) + self._add_field_by_type(field_name, FieldType.SHORT, self.version) return self def add_int_field(self, field_name): - self._add_field_by_type(field_name, FieldType.INT) + self._add_field_by_type(field_name, FieldType.INT, self.version) return self def add_long_field(self, field_name): - self._add_field_by_type(field_name, FieldType.LONG) + self._add_field_by_type(field_name, FieldType.LONG, self.version) return self def add_float_field(self, field_name): - self._add_field_by_type(field_name, FieldType.FLOAT) + self._add_field_by_type(field_name, FieldType.FLOAT, self.version) return self def add_double_field(self, field_name): - self._add_field_by_type(field_name, FieldType.DOUBLE) + self._add_field_by_type(field_name, FieldType.DOUBLE, self.version) return self def add_utf_field(self, field_name): - self._add_field_by_type(field_name, FieldType.UTF) + self._add_field_by_type(field_name, FieldType.UTF, self.version) return self def add_portable_array_field(self, field_name, class_def): if class_def.class_id is None or class_def.class_id == 0: raise ValueError("Portable class id cannot be zero!") - self._add_field_by_type(field_name, FieldType.PORTABLE_ARRAY, class_def.factory_id, class_def.class_id) + self._add_field_by_type(field_name, FieldType.PORTABLE_ARRAY, class_def.version, + class_def.factory_id, class_def.class_id) return self def add_byte_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.BYTE_ARRAY) + self._add_field_by_type(field_name, FieldType.BYTE_ARRAY, self.version) return self def add_boolean_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.BOOLEAN_ARRAY) + self._add_field_by_type(field_name, FieldType.BOOLEAN_ARRAY, self.version) return self def add_char_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.CHAR_ARRAY) + self._add_field_by_type(field_name, FieldType.CHAR_ARRAY, self.version) return self def add_short_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.SHORT_ARRAY) + self._add_field_by_type(field_name, FieldType.SHORT_ARRAY, self.version) return self def add_int_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.INT_ARRAY) + self._add_field_by_type(field_name, FieldType.INT_ARRAY, self.version) return self def add_long_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.LONG_ARRAY) + self._add_field_by_type(field_name, FieldType.LONG_ARRAY, self.version) return self def add_float_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.FLOAT_ARRAY) + self._add_field_by_type(field_name, FieldType.FLOAT_ARRAY, self.version) return self def add_double_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.DOUBLE_ARRAY) + self._add_field_by_type(field_name, FieldType.DOUBLE_ARRAY, self.version) return self def add_utf_array_field(self, field_name): - self._add_field_by_type(field_name, FieldType.UTF_ARRAY) + self._add_field_by_type(field_name, FieldType.UTF_ARRAY, self.version) return self def add_field_def(self, field_def): @@ -214,9 +221,9 @@ def build(self): cd.add_field_def(field_def) return cd - def _add_field_by_type(self, field_name, field_type, factory_id=0, class_id=0): + def _add_field_by_type(self, field_name, field_type, version, factory_id=0, class_id=0): self._check() - self._field_defs.append(FieldDefinition(self._index, field_name, field_type, factory_id, class_id)) + self._field_defs.append(FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id)) self._index += 1 def _check(self): diff --git a/hazelcast/serialization/portable/context.py b/hazelcast/serialization/portable/context.py index a6e8839d75..3e0a41736e 100644 --- a/hazelcast/serialization/portable/context.py +++ b/hazelcast/serialization/portable/context.py @@ -45,6 +45,7 @@ def read_class_definition(self, data_in, factory_id, class_id, version): field_factory_id = 0 field_class_id = 0 + field_version = version if field_type == FieldType.PORTABLE: # is null if data_in.read_boolean(): @@ -69,7 +70,8 @@ def read_class_definition(self, data_in, factory_id, class_id, version): self.read_class_definition(data_in, field_factory_id, field_class_id, field_version) else: register = False - builder.add_field_def(FieldDefinition(i, field_name.decode('ascii'), field_type, field_factory_id, field_class_id)) + builder.add_field_def(FieldDefinition(i, field_name.decode('ascii'), field_type, field_version, + field_factory_id, field_class_id)) class_def = builder.build() if register: class_def = self.register_class_definition(class_def) diff --git a/hazelcast/serialization/predicate.py b/hazelcast/serialization/predicate.py index 15a6415dc4..b83515c6fc 100644 --- a/hazelcast/serialization/predicate.py +++ b/hazelcast/serialization/predicate.py @@ -220,6 +220,7 @@ def __repr__(self): false = FalsePredicate true = TruePredicate + def is_greater_than(attribute, x): return GreaterLessPredicate(attribute, x, False, False) diff --git a/hazelcast/serialization/service.py b/hazelcast/serialization/service.py index 7754b03649..fb7a31774f 100644 --- a/hazelcast/serialization/service.py +++ b/hazelcast/serialization/service.py @@ -20,13 +20,12 @@ def default_partition_strategy(key): class SerializationServiceV1(BaseSerializationService): logger = logging.getLogger("SerializationService") - def __init__(self, serialization_config, version=1, portable_version=0, - global_partition_strategy=default_partition_strategy, + def __init__(self, serialization_config, version=1, global_partition_strategy=default_partition_strategy, output_buffer_size=DEFAULT_OUT_BUFFER_SIZE): super(SerializationServiceV1, self).__init__(version, global_partition_strategy, output_buffer_size, serialization_config.is_big_endian, serialization_config.default_integer_type) - self._portable_context = PortableContext(self, portable_version) + self._portable_context = PortableContext(self, serialization_config.portable_version) self.register_class_definitions(serialization_config.class_definitions, serialization_config.check_class_def_errors) self._registry._portable_serializer = PortableSerializer(self._portable_context, serialization_config.portable_factories) diff --git a/tests/predicate_test.py b/tests/predicate_test.py index 082cc95114..da902e1f93 100644 --- a/tests/predicate_test.py +++ b/tests/predicate_test.py @@ -3,6 +3,7 @@ from hazelcast.serialization.predicate import is_equal_to, and_, is_between, is_less_than, \ is_less_than_or_equal_to, is_greater_than, is_greater_than_or_equal_to, or_, is_not_equal_to, not_, is_like, \ is_ilike, matches_regex, sql, true, false, is_in, is_instance_of +from hazelcast.serialization.api import Portable from tests.base import SingleMemberTestCase from tests.serialization.portable_test import InnerPortable, FACTORY_ID from tests.util import random_string @@ -237,3 +238,88 @@ def test_predicate_portable_key(self): for k in key_set: self.assertGreaterEqual(k.param_int, 900) self.assertIn(k, map_keys) + + +class NestedPredicatePortableTest(SingleMemberTestCase): + + class Body(Portable): + def __init__(self, name=None, limb=None): + self.name = name + self.limb = limb + + def get_class_id(self): + return 1 + + def get_factory_id(self): + return 1 + + def get_class_version(self): + return 15 + + def write_portable(self, writer): + writer.write_utf("name", self.name) + writer.write_portable("limb", self.limb) + + def read_portable(self, reader): + self.name = reader.read_utf("name") + self.limb = reader.read_portable("limb") + + def __eq__(self, other): + return isinstance(other, self.__class__) and (self.name, self.limb) == (other.name, other.limb) + + class Limb(Portable): + def __init__(self, name=None): + self.name = name + + def get_class_id(self): + return 2 + + def get_factory_id(self): + return 1 + + def get_class_version(self): + return 2 + + def write_portable(self, writer): + writer.write_utf("name", self.name) + + def read_portable(self, reader): + self.name = reader.read_utf("name") + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.name == other.name + + @classmethod + def configure_client(cls, config): + factory = {1: NestedPredicatePortableTest.Body, 2: NestedPredicatePortableTest.Limb} + config.serialization_config.portable_factories[FACTORY_ID] = factory + return config + + def setUp(self): + self.map = self.client.get_map(random_string()).blocking() + self.map.put(1, NestedPredicatePortableTest.Body("body1", NestedPredicatePortableTest.Limb("hand"))) + self.map.put(2, NestedPredicatePortableTest.Body("body2", NestedPredicatePortableTest.Limb("leg"))) + + def tearDown(self): + self.map.destroy() + + def test_adding_indexes(self): + # single-attribute index + self.map.add_index("name", True) + + # nested-attribute index + self.map.add_index("limb.name", True) + + def test_single_attribute_query_portable_predicates(self): + predicate = is_equal_to("limb.name", "hand") + values = self.map.values(predicate) + + self.assertEqual(1, len(values)) + self.assertEqual("body1", values[0].name) + + def test_nested_attribute_query_sql_predicate(self): + predicate = sql("limb.name == 'leg'") + values = self.map.values(predicate) + + self.assertEqual(1, len(values)) + self.assertEqual("body2", values[0].name) diff --git a/tests/serialization/portable_test.py b/tests/serialization/portable_test.py index 8cfa4a0c5b..b47310755a 100644 --- a/tests/serialization/portable_test.py +++ b/tests/serialization/portable_test.py @@ -175,6 +175,46 @@ def __hash__(self): return id(self)//16 +class Parent(Portable): + def __init__(self, child=None): + self.child = child + + def get_class_id(self): + return 1 + + def get_factory_id(self): + return 1 + + def write_portable(self, writer): + writer.write_portable("child", self.child) + + def read_portable(self, reader): + self.child = reader.read_portable("child") + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.child == other.child + + +class Child(Portable): + def __init__(self, name=None): + self.name = name + + def get_factory_id(self): + return 1 + + def get_class_id(self): + return 2 + + def write_portable(self, writer): + writer.write_utf("name", self.name) + + def read_portable(self, reader): + self.name = reader.read_utf("name") + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.name == other.name + + def create_portable(): identified = create_identified() inner_portable = InnerPortable("Inner Text", 666) @@ -296,3 +336,20 @@ def test_portable_read_without_factory(self): data = service.to_data(obj) with self.assertRaises(HazelcastSerializationError): service2.to_object(data) + + def test_nested_portable_serialization(self): + serialization_config = hazelcast.SerializationConfig() + serialization_config.portable_version = 6 + + serialization_config.portable_factories[1] = {1: Parent, 2: Child} + + ss1 = SerializationServiceV1(serialization_config) + ss2 = SerializationServiceV1(serialization_config) + + ss2.to_data(Child("Joe")) + + p = Parent(Child("Joe")) + + data = ss1.to_data(p) + + self.assertEqual(p, ss2.to_object(data))