From cc23f6e5a8b9336acf106eeacbb47d2ee9d4ff04 Mon Sep 17 00:00:00 2001 From: mdumandag Date: Mon, 26 Oct 2020 13:26:50 +0300 Subject: [PATCH 1/2] Fix class definition registrations with the same class id When class definitions for two different classes with the same class id but different factory ids are registered, we were throwing an error indicating that there are duplicate registrations. However, we should allow such cases. Apart from the fix for that, the PR also includes a test case for null portable serialization. Closes https://github.com/hazelcast/hazelcast-python-client/issues/199 --- hazelcast/serialization/portable/classdef.py | 3 +- hazelcast/serialization/service.py | 40 ++++--- tests/serialization/portable_test.py | 103 +++++++++++++++++++ 3 files changed, 133 insertions(+), 13 deletions(-) diff --git a/hazelcast/serialization/portable/classdef.py b/hazelcast/serialization/portable/classdef.py index bc418bf8b9..4b5446f18c 100644 --- a/hazelcast/serialization/portable/classdef.py +++ b/hazelcast/serialization/portable/classdef.py @@ -225,7 +225,8 @@ def build(self): def _add_field_by_type(self, field_name, field_type, version, factory_id=0, class_id=0): self._check() - self._field_defs.append(FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id)) + fd = FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id) + self._field_defs.append(fd) self._index += 1 def _check(self): diff --git a/hazelcast/serialization/service.py b/hazelcast/serialization/service.py index 0cdd96b53b..a9613026cc 100644 --- a/hazelcast/serialization/service.py +++ b/hazelcast/serialization/service.py @@ -78,24 +78,40 @@ def _register_constant_serializers(self): self._registry.safe_register_serializer(self._registry._python_serializer) def register_class_definitions(self, class_definitions, check_error): - class_defs = dict() + factories = dict() for cd in class_definitions: - if cd in class_defs: - raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % cd.class_id) - class_defs[cd.class_id] = cd + factory_id = cd.factory_id + class_defs = factories.get(factory_id, None) + if class_defs is None: + class_defs = dict() + factories[factory_id] = class_defs + + class_id = cd.class_id + if class_id in class_defs: + raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % class_id) + class_defs[class_id] = cd + for cd in class_definitions: - self.register_class_definition(cd, class_defs, check_error) + self.register_class_definition(cd, factories, check_error) - def register_class_definition(self, cd, class_defs, check_error): + def register_class_definition(self, cd, factories, check_error): field_names = cd.get_field_names() for field_name in field_names: fd = cd.get_field(field_name) if fd.field_type == FieldType.PORTABLE or fd.field_type == FieldType.PORTABLE_ARRAY: - nested_cd = class_defs.get(fd.class_id, None) - if nested_cd is not None: - self.register_class_definition(nested_cd, class_defs, check_error) - self._portable_context.register_class_definition(nested_cd) - elif check_error: + factory_id = fd.factory_id + class_id = fd.class_id + class_defs = factories.get(factory_id, None) + if class_defs is not None: + nested_cd = class_defs.get(class_id, None) + if nested_cd is not None: + self.register_class_definition(nested_cd, factories, check_error) + self._portable_context.register_class_definition(nested_cd) + continue + + if check_error: raise HazelcastSerializationError( - "Could not find registered ClassDefinition for class-id: %s" % fd.class_id) + "Could not find registered ClassDefinition for factory-id: %s, class-id: %s" + % (factory_id, class_id)) + self._portable_context.register_class_definition(cd) diff --git a/tests/serialization/portable_test.py b/tests/serialization/portable_test.py index ba121eed2a..075a1712a1 100644 --- a/tests/serialization/portable_test.py +++ b/tests/serialization/portable_test.py @@ -231,6 +231,51 @@ def create_portable(): the_factory = {SerializationV1Portable.CLASS_ID: SerializationV1Portable, InnerPortable.CLASS_ID: InnerPortable} +class MyPortable1(Portable): + def __init__(self, str_field=None): + self.str_field = str_field + + def write_portable(self, writer): + writer.write_utf("str_field", self.str_field) + + def read_portable(self, reader): + self.str_field = reader.read_utf("str_field") + + def get_factory_id(self): + return 1 + + def get_class_id(self): + return 1 + + def __eq__(self, other): + return isinstance(other, MyPortable1) and self.str_field == other.str_field + + def __ne__(self, other): + return not self.__eq__(other) + + +class MyPortable2(Portable): + def __init__(self, int_field=0): + self.int_field = int_field + + def write_portable(self, writer): + writer.write_int("int_field", self.int_field) + + def read_portable(self, reader): + self.int_field = reader.read_int("int_field") + + def get_factory_id(self): + return 2 + + def get_class_id(self): + return 1 + + def __eq__(self, other): + return isinstance(other, MyPortable2) and self.int_field == other.int_field + + def __ne__(self, other): + return not self.__eq__(other) + class PortableSerializationTestCase(unittest.TestCase): def test_encode_decode(self): config = _Config() @@ -370,3 +415,61 @@ def test_nested_portable_serialization(self): data = ss1.to_data(p) self.assertEqual(p, ss2.to_object(data)) + + def test_nested_null_portable_serialization(self): + config = _Config() + + config.portable_factories = { + 1: { + 1: Parent, + 2: Child + } + } + + child_class_def = ClassDefinitionBuilder(FACTORY_ID, 2).add_utf_field("name").build() + parent_class_def = ClassDefinitionBuilder(FACTORY_ID, 1).add_portable_field("child", child_class_def).build() + + config.class_definitions = [child_class_def, parent_class_def] + + ss = SerializationServiceV1(config) + + p = Parent(None) + data = ss.to_data(p) + + self.assertEqual(p, ss.to_object(data)) + + def test_duplicate_class_definition(self): + config = _Config() + + class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build() + class_def2 = ClassDefinitionBuilder(1, 1).add_int_field("int_field").build() + + config.class_definitions = [class_def1, class_def2] + + with self.assertRaises(HazelcastSerializationError): + SerializationServiceV1(config) + + def test_classes_with_same_class_id_in_different_factories(self): + config = _Config() + config.portable_factories = { + 1: { + 1: MyPortable1 + }, + 2: { + 1: MyPortable2 + } + } + + class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build() + class_def2 = ClassDefinitionBuilder(2, 1).add_int_field("int_field").build() + + config.class_definitions = [class_def1, class_def2] + ss = SerializationServiceV1(config) + + portable1 = MyPortable1("test") + data1 = ss.to_data(portable1) + self.assertEqual(portable1, ss.to_object(data1)) + + portable2 = MyPortable2(1) + data2 = ss.to_data(portable2) + self.assertEqual(portable2, ss.to_object(data2)) From 24de635911a8b160d45d59655fe9eb6fb9630b0d Mon Sep 17 00:00:00 2001 From: mdumandag Date: Mon, 9 Nov 2020 10:17:26 +0300 Subject: [PATCH 2/2] rename parameter from cd to class_definition --- hazelcast/serialization/service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hazelcast/serialization/service.py b/hazelcast/serialization/service.py index a9613026cc..ff959967e3 100644 --- a/hazelcast/serialization/service.py +++ b/hazelcast/serialization/service.py @@ -94,10 +94,10 @@ def register_class_definitions(self, class_definitions, check_error): for cd in class_definitions: self.register_class_definition(cd, factories, check_error) - def register_class_definition(self, cd, factories, check_error): - field_names = cd.get_field_names() + def register_class_definition(self, class_definition, factories, check_error): + field_names = class_definition.get_field_names() for field_name in field_names: - fd = cd.get_field(field_name) + fd = class_definition.get_field(field_name) if fd.field_type == FieldType.PORTABLE or fd.field_type == FieldType.PORTABLE_ARRAY: factory_id = fd.factory_id class_id = fd.class_id @@ -114,4 +114,4 @@ def register_class_definition(self, cd, factories, check_error): "Could not find registered ClassDefinition for factory-id: %s, class-id: %s" % (factory_id, class_id)) - self._portable_context.register_class_definition(cd) + self._portable_context.register_class_definition(class_definition)