diff --git a/hazelcast/serialization/portable/classdef.py b/hazelcast/serialization/portable/classdef.py index bc418bf8b9..4b5446f18c 100644 --- a/hazelcast/serialization/portable/classdef.py +++ b/hazelcast/serialization/portable/classdef.py @@ -225,7 +225,8 @@ def build(self): def _add_field_by_type(self, field_name, field_type, version, factory_id=0, class_id=0): self._check() - self._field_defs.append(FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id)) + fd = FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id) + self._field_defs.append(fd) self._index += 1 def _check(self): diff --git a/hazelcast/serialization/service.py b/hazelcast/serialization/service.py index 0cdd96b53b..ff959967e3 100644 --- a/hazelcast/serialization/service.py +++ b/hazelcast/serialization/service.py @@ -78,24 +78,40 @@ def _register_constant_serializers(self): self._registry.safe_register_serializer(self._registry._python_serializer) def register_class_definitions(self, class_definitions, check_error): - class_defs = dict() + factories = dict() for cd in class_definitions: - if cd in class_defs: - raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % cd.class_id) - class_defs[cd.class_id] = cd + factory_id = cd.factory_id + class_defs = factories.get(factory_id, None) + if class_defs is None: + class_defs = dict() + factories[factory_id] = class_defs + + class_id = cd.class_id + if class_id in class_defs: + raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % class_id) + class_defs[class_id] = cd + for cd in class_definitions: - self.register_class_definition(cd, class_defs, check_error) + self.register_class_definition(cd, factories, check_error) - def register_class_definition(self, cd, class_defs, check_error): - field_names = cd.get_field_names() + def register_class_definition(self, class_definition, factories, check_error): + field_names = class_definition.get_field_names() for field_name in field_names: - fd = cd.get_field(field_name) + fd = class_definition.get_field(field_name) if fd.field_type == FieldType.PORTABLE or fd.field_type == FieldType.PORTABLE_ARRAY: - nested_cd = class_defs.get(fd.class_id, None) - if nested_cd is not None: - self.register_class_definition(nested_cd, class_defs, check_error) - self._portable_context.register_class_definition(nested_cd) - elif check_error: + factory_id = fd.factory_id + class_id = fd.class_id + class_defs = factories.get(factory_id, None) + if class_defs is not None: + nested_cd = class_defs.get(class_id, None) + if nested_cd is not None: + self.register_class_definition(nested_cd, factories, check_error) + self._portable_context.register_class_definition(nested_cd) + continue + + if check_error: raise HazelcastSerializationError( - "Could not find registered ClassDefinition for class-id: %s" % fd.class_id) - self._portable_context.register_class_definition(cd) + "Could not find registered ClassDefinition for factory-id: %s, class-id: %s" + % (factory_id, class_id)) + + self._portable_context.register_class_definition(class_definition) diff --git a/tests/serialization/portable_test.py b/tests/serialization/portable_test.py index ba121eed2a..075a1712a1 100644 --- a/tests/serialization/portable_test.py +++ b/tests/serialization/portable_test.py @@ -231,6 +231,51 @@ def create_portable(): the_factory = {SerializationV1Portable.CLASS_ID: SerializationV1Portable, InnerPortable.CLASS_ID: InnerPortable} +class MyPortable1(Portable): + def __init__(self, str_field=None): + self.str_field = str_field + + def write_portable(self, writer): + writer.write_utf("str_field", self.str_field) + + def read_portable(self, reader): + self.str_field = reader.read_utf("str_field") + + def get_factory_id(self): + return 1 + + def get_class_id(self): + return 1 + + def __eq__(self, other): + return isinstance(other, MyPortable1) and self.str_field == other.str_field + + def __ne__(self, other): + return not self.__eq__(other) + + +class MyPortable2(Portable): + def __init__(self, int_field=0): + self.int_field = int_field + + def write_portable(self, writer): + writer.write_int("int_field", self.int_field) + + def read_portable(self, reader): + self.int_field = reader.read_int("int_field") + + def get_factory_id(self): + return 2 + + def get_class_id(self): + return 1 + + def __eq__(self, other): + return isinstance(other, MyPortable2) and self.int_field == other.int_field + + def __ne__(self, other): + return not self.__eq__(other) + class PortableSerializationTestCase(unittest.TestCase): def test_encode_decode(self): config = _Config() @@ -370,3 +415,61 @@ def test_nested_portable_serialization(self): data = ss1.to_data(p) self.assertEqual(p, ss2.to_object(data)) + + def test_nested_null_portable_serialization(self): + config = _Config() + + config.portable_factories = { + 1: { + 1: Parent, + 2: Child + } + } + + child_class_def = ClassDefinitionBuilder(FACTORY_ID, 2).add_utf_field("name").build() + parent_class_def = ClassDefinitionBuilder(FACTORY_ID, 1).add_portable_field("child", child_class_def).build() + + config.class_definitions = [child_class_def, parent_class_def] + + ss = SerializationServiceV1(config) + + p = Parent(None) + data = ss.to_data(p) + + self.assertEqual(p, ss.to_object(data)) + + def test_duplicate_class_definition(self): + config = _Config() + + class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build() + class_def2 = ClassDefinitionBuilder(1, 1).add_int_field("int_field").build() + + config.class_definitions = [class_def1, class_def2] + + with self.assertRaises(HazelcastSerializationError): + SerializationServiceV1(config) + + def test_classes_with_same_class_id_in_different_factories(self): + config = _Config() + config.portable_factories = { + 1: { + 1: MyPortable1 + }, + 2: { + 1: MyPortable2 + } + } + + class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build() + class_def2 = ClassDefinitionBuilder(2, 1).add_int_field("int_field").build() + + config.class_definitions = [class_def1, class_def2] + ss = SerializationServiceV1(config) + + portable1 = MyPortable1("test") + data1 = ss.to_data(portable1) + self.assertEqual(portable1, ss.to_object(data1)) + + portable2 = MyPortable2(1) + data2 = ss.to_data(portable2) + self.assertEqual(portable2, ss.to_object(data2))