Skip to content

Commit

Permalink
Add ObjectListBase concat methods
Browse files Browse the repository at this point in the history
This adds support for using + to concatenate two objects derived from
ObjectListBase. The original objects are not modified, a new list object
is returned.

The requirements for this are that the two objects are
of the same class, and there is only an 'objects' field.

Change-Id: I4610db6b52f3f576a6d0c2e64af39927077586cd
  • Loading branch information
alaski committed Sep 15, 2016
1 parent 020dd5d commit f351948
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
22 changes: 22 additions & 0 deletions oslo_versionedobjects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,28 @@ def obj_what_changed(self):
changes.add('objects')
return changes

def __add__(self, other):
# Handling arbitrary fields may not make sense if those fields are not
# all concatenatable. Only concatenate if the base 'objects' field is
# the only one and the classes match.
if (self.__class__ == other.__class__ and
list(self.__class__.fields.keys()) == ['objects']):
return self.__class__(objects=self.objects + other.objects)
else:
raise TypeError("List Objects should be of the same type and only "
"have an 'objects' field")

def __radd__(self, other):
if (self.__class__ == other.__class__ and
list(self.__class__.fields.keys()) == ['objects']):
# This should never be run in practice. If the above condition is
# met then __add__ would have been run.
raise NotImplementedError('__radd__ is not implemented for '
'objects of the same type')
else:
raise TypeError("List Objects should be of the same type and only "
"have an 'objects' field")


class VersionedObjectSerializer(messaging.NoOpSerializer):
"""A VersionedObject-aware Serializer.
Expand Down
89 changes: 89 additions & 0 deletions oslo_versionedobjects/tests/test_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,3 +2251,92 @@ class TestObject(base.VersionedObject):
'TestChild': '2.34',
'TestChildTwo': '4.56'},
tree)


class TestListObjectConcat(test.TestCase):
def test_list_object_concat(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}

values = [1, 2, 42]

list1 = MyList(objects=[MyOwnedObject(baz=values[0]),
MyOwnedObject(baz=values[1])])
list2 = MyList(objects=[MyOwnedObject(baz=values[2])])

concat_list = list1 + list2
for idx, obj in enumerate(concat_list):
self.assertEqual(values[idx], obj.baz)

# Assert that the original lists are unmodified
self.assertEqual(2, len(list1.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list1.objects[1].baz)
self.assertEqual(1, len(list2.objects))
self.assertEqual(42, list2.objects[0].baz)

def test_list_object_concat_fails_different_objects(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}

@base.VersionedObjectRegistry.register_if(False)
class MyList2(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}

list1 = MyList(objects=[MyOwnedObject(baz=1)])
list2 = MyList2(objects=[MyOwnedObject(baz=2)])

def add(x, y):
return x + y

self.assertRaises(TypeError, add, list1, list2)
# Assert that the original lists are unmodified
self.assertEqual(1, len(list1.objects))
self.assertEqual(1, len(list2.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list2.objects[0].baz)

def test_list_object_concat_fails_extra_fields(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject'),
'foo': fields.IntegerField(nullable=True)}

list1 = MyList(objects=[MyOwnedObject(baz=1)])
list2 = MyList(objects=[MyOwnedObject(baz=2)])

def add(x, y):
return x + y

self.assertRaises(TypeError, add, list1, list2)
# Assert that the original lists are unmodified
self.assertEqual(1, len(list1.objects))
self.assertEqual(1, len(list2.objects))
self.assertEqual(1, list1.objects[0].baz)
self.assertEqual(2, list2.objects[0].baz)

def test_builtin_list_add_fails(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}

list1 = MyList(objects=[MyOwnedObject(baz=1)])

def add(obj):
return obj + []

self.assertRaises(TypeError, add, list1)

def test_builtin_list_radd_fails(self):
@base.VersionedObjectRegistry.register_if(False)
class MyList(base.ObjectListBase, base.VersionedObject):
fields = {'objects': fields.ListOfObjectsField('MyOwnedObject')}

list1 = MyList(objects=[MyOwnedObject(baz=1)])

def add(obj):
return [] + obj

self.assertRaises(TypeError, add, list1)

0 comments on commit f351948

Please sign in to comment.