Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hazelcast/serialization/portable/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ def build(self):

def _add_field_by_type(self, field_name, field_type, version, factory_id=0, class_id=0):
self._check()
self._field_defs.append(FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id))
fd = FieldDefinition(self._index, field_name, field_type, version, factory_id, class_id)
self._field_defs.append(fd)
self._index += 1

def _check(self):
Expand Down
46 changes: 31 additions & 15 deletions hazelcast/serialization/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,40 @@ def _register_constant_serializers(self):
self._registry.safe_register_serializer(self._registry._python_serializer)

def register_class_definitions(self, class_definitions, check_error):
class_defs = dict()
factories = dict()
for cd in class_definitions:
if cd in class_defs:
raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % cd.class_id)
class_defs[cd.class_id] = cd
factory_id = cd.factory_id
class_defs = factories.get(factory_id, None)
if class_defs is None:
class_defs = dict()
factories[factory_id] = class_defs

class_id = cd.class_id
if class_id in class_defs:
raise HazelcastSerializationError("Duplicate registration found for class-id: %s" % class_id)
class_defs[class_id] = cd

for cd in class_definitions:
self.register_class_definition(cd, class_defs, check_error)
self.register_class_definition(cd, factories, check_error)

def register_class_definition(self, cd, class_defs, check_error):
field_names = cd.get_field_names()
def register_class_definition(self, class_definition, factories, check_error):
field_names = class_definition.get_field_names()
for field_name in field_names:
fd = cd.get_field(field_name)
fd = class_definition.get_field(field_name)
if fd.field_type == FieldType.PORTABLE or fd.field_type == FieldType.PORTABLE_ARRAY:
nested_cd = class_defs.get(fd.class_id, None)
if nested_cd is not None:
self.register_class_definition(nested_cd, class_defs, check_error)
self._portable_context.register_class_definition(nested_cd)
elif check_error:
factory_id = fd.factory_id
class_id = fd.class_id
class_defs = factories.get(factory_id, None)
if class_defs is not None:
nested_cd = class_defs.get(class_id, None)
if nested_cd is not None:
self.register_class_definition(nested_cd, factories, check_error)
self._portable_context.register_class_definition(nested_cd)
continue

if check_error:
raise HazelcastSerializationError(
"Could not find registered ClassDefinition for class-id: %s" % fd.class_id)
self._portable_context.register_class_definition(cd)
"Could not find registered ClassDefinition for factory-id: %s, class-id: %s"
% (factory_id, class_id))

self._portable_context.register_class_definition(class_definition)
103 changes: 103 additions & 0 deletions tests/serialization/portable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,51 @@ def create_portable():
the_factory = {SerializationV1Portable.CLASS_ID: SerializationV1Portable, InnerPortable.CLASS_ID: InnerPortable}


class MyPortable1(Portable):
def __init__(self, str_field=None):
self.str_field = str_field

def write_portable(self, writer):
writer.write_utf("str_field", self.str_field)

def read_portable(self, reader):
self.str_field = reader.read_utf("str_field")

def get_factory_id(self):
return 1

def get_class_id(self):
return 1

def __eq__(self, other):
return isinstance(other, MyPortable1) and self.str_field == other.str_field

def __ne__(self, other):
return not self.__eq__(other)


class MyPortable2(Portable):
def __init__(self, int_field=0):
self.int_field = int_field

def write_portable(self, writer):
writer.write_int("int_field", self.int_field)

def read_portable(self, reader):
self.int_field = reader.read_int("int_field")

def get_factory_id(self):
return 2

def get_class_id(self):
return 1

def __eq__(self, other):
return isinstance(other, MyPortable2) and self.int_field == other.int_field

def __ne__(self, other):
return not self.__eq__(other)

class PortableSerializationTestCase(unittest.TestCase):
def test_encode_decode(self):
config = _Config()
Expand Down Expand Up @@ -370,3 +415,61 @@ def test_nested_portable_serialization(self):
data = ss1.to_data(p)

self.assertEqual(p, ss2.to_object(data))

def test_nested_null_portable_serialization(self):
config = _Config()

config.portable_factories = {
1: {
1: Parent,
2: Child
}
}

child_class_def = ClassDefinitionBuilder(FACTORY_ID, 2).add_utf_field("name").build()
parent_class_def = ClassDefinitionBuilder(FACTORY_ID, 1).add_portable_field("child", child_class_def).build()

config.class_definitions = [child_class_def, parent_class_def]

ss = SerializationServiceV1(config)

p = Parent(None)
data = ss.to_data(p)

self.assertEqual(p, ss.to_object(data))

def test_duplicate_class_definition(self):
config = _Config()

class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build()
class_def2 = ClassDefinitionBuilder(1, 1).add_int_field("int_field").build()

config.class_definitions = [class_def1, class_def2]

with self.assertRaises(HazelcastSerializationError):
SerializationServiceV1(config)

def test_classes_with_same_class_id_in_different_factories(self):
config = _Config()
config.portable_factories = {
1: {
1: MyPortable1
},
2: {
1: MyPortable2
}
}

class_def1 = ClassDefinitionBuilder(1, 1).add_utf_field("str_field").build()
class_def2 = ClassDefinitionBuilder(2, 1).add_int_field("int_field").build()

config.class_definitions = [class_def1, class_def2]
ss = SerializationServiceV1(config)

portable1 = MyPortable1("test")
data1 = ss.to_data(portable1)
self.assertEqual(portable1, ss.to_object(data1))

portable2 = MyPortable2(1)
data2 = ss.to_data(portable2)
self.assertEqual(portable2, ss.to_object(data2))