Skip to content

Commit

Permalink
Fix case where non-AbstractContainer is base class (#444)
Browse files Browse the repository at this point in the history
  • Loading branch information
rly committed Oct 23, 2020
1 parent 06064be commit 64c1c97
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/hdmf/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,20 @@ def __gather_fields(cls, name, bases, classdict):

# check whether this class overrides __fields__
if len(bases):
base_fields = bases[-1]._get_fields() # tuple of field names from base class
# find highest base class that is an AbstractContainer (parent is higher than children)
base_cls = None
for base_cls in reversed(bases):
if issubclass(base_cls, AbstractContainer):
break
base_fields = base_cls._get_fields() # tuple of field names from base class
if base_fields is not fields:
# check whether new fields spec already exists in base class
for field_name in fields_dict:
if field_name in base_fields:
raise ValueError("Field '%s' cannot be defined in %s. It already exists on base class %s."
% (field_name, cls.__name__, bases[-1].__name__))
% (field_name, cls.__name__, base_cls.__name__))
# prepend field specs from base class to fields list of this class
all_fields_conf[0:0] = bases[-1].get_fields_conf()
all_fields_conf[0:0] = base_cls.get_fields_conf()

# create getter and setter if attribute does not already exist
# if 'doc' not specified in __fields__, use doc from docval of __init__
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,32 @@ class NamedFields(AbstractContainer):
class NamedFieldsChild(NamedFields):
__fields__ = ({'name': 'field1', 'settable': True}, )

def test_mult_inheritance_base_mixin(self):
class NamedFields(AbstractContainer):
__fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, )

class BlankMixin:
pass

class NamedFieldsChild(NamedFields, BlankMixin):
__fields__ = ({'name': 'field2'}, )

self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2'))
self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__)

def test_mult_inheritance_base_container(self):
class NamedFields(AbstractContainer):
__fields__ = ({'name': 'field1', 'doc': 'field1 doc', 'settable': False}, )

class BlankMixin:
pass

class NamedFieldsChild(BlankMixin, NamedFields):
__fields__ = ({'name': 'field2'}, )

self.assertTupleEqual(NamedFieldsChild.__fields__, ('field1', 'field2'))
self.assertIs(NamedFieldsChild._get_fields(), NamedFieldsChild.__fields__)


class TestContainerFieldsConf(TestCase):

Expand Down

0 comments on commit 64c1c97

Please sign in to comment.