Skip to content

Commit

Permalink
Fix DynamicTable slicing and add SimpleMultiContainer (#449)
Browse files Browse the repository at this point in the history
* Fix case where non-AbstractContainer is base class

* Update container.py

* only use astype if not already requested type

* fix slicing bugs in DynamicTable

* add SimpleMultiContainer

* add method to check if data_type is a subtype of another

* update inits

* fix slicing and test roundtrip

* update schema submodule

* update changelog

* remove unnecessary method

* fix documentation

* add more tests for DynamicTableRegion indexing

* remove f-string

* construct literal DataFrames with consistent order

* simplify conditional

* add tuple

* ignore dtype when checking dataframes

* remove tuple

* remove check on type of data in column

* remove ElementIdentifiers data object type check

* convert to list for h5py 2.9 compatibility

* use assertRaisesWith

* simplify SimpleMultiContainer tests

* Update tests/unit/common/test_multi.py

Co-authored-by: Ryan Ly <rly@lbl.gov>

* pin schema revision

* update schema submodule

* Update CHANGELOG.md

Co-authored-by: Ryan Ly <rly@lbl.gov>

Co-authored-by: Ryan Ly <rly@lbl.gov>
  • Loading branch information
ajtritt and rly committed Nov 5, 2020
1 parent a15fd20 commit f4b8232
Show file tree
Hide file tree
Showing 13 changed files with 307 additions and 29 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
@bendichter, @rly (#430)
- Add capability to add a row to a column after IO. @bendichter (#426)
- Add method `hdmf.utils.get_docval_macro` to get a tuple of the current values for a docval_macro, e.g., 'array_data'
and 'scalar_data'. @rly (#456)
and 'scalar_data'. @rly (#446)
- Add SimpleMultiContainer, a data_type for storing a Container and Data objects together. @ajtritt (#449)
- Support `pathlib.Path` paths in `HDMFIO.__init__`, `HDF5IO.__init__`, and `HDF5IO.load_namespaces`. @dsleiter (#439)
- Use hdmf-common-schema 1.2.1. See https://hdmf-common-schema.readthedocs.io/en/latest/format_release_notes.html for details.

### Internal improvements
- Refactor `HDF5IO.write_dataset` to be more readable. @rly (#428)
- Fix bug in slicing tables with DynamicTableRegions. @ajtritt (#449)

### Bug fixes
- Fix development package dependency issues. @rly (#431)
Expand Down
13 changes: 13 additions & 0 deletions src/hdmf/build/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def get_builder_dt(self, **kwargs):
builder = getargs('builder', kwargs)
return self.__type_map.get_builder_dt(builder)

@docval({'name': 'builder', 'type': (GroupBuilder, DatasetBuilder), 'doc': 'the builder to check'},
        {'name': 'parent_data_type', 'type': (str, type), 'doc': 'the potential parent data_type'},
        returns="True if data_type of *builder* is a sub-data_type of *parent_data_type*, False otherwise",
        rtype=bool)
def is_sub_data_type(self, **kwargs):
    """Check whether the data_type of *builder* inherits from *parent_data_type*.

    Resolves the builder's data_type and namespace, then delegates the
    hierarchy check to the namespace catalog.
    """
    builder, parent_data_type = getargs('builder', 'parent_data_type', kwargs)
    child_data_type = self.get_builder_dt(builder)
    child_namespace = self.get_builder_ns(builder)
    return self.namespace_catalog.is_sub_data_type(child_namespace, child_data_type, parent_data_type)


class TypeSource:
'''A class to indicate the source of a data_type in a namespace.
Expand Down
10 changes: 6 additions & 4 deletions src/hdmf/build/objectmapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def no_convert(cls, obj_type):
"""
cls.__no_convert.add(obj_type)

@classmethod
def convert_dtype(cls, spec, value, spec_dtype=None):
@classmethod # noqa: C901
def convert_dtype(cls, spec, value, spec_dtype=None): # noqa: C901
"""
Convert values to the specified dtype. For example, if a literal int
is passed in to a field that is specified as a unsigned integer, this function
Expand Down Expand Up @@ -205,7 +205,10 @@ def convert_dtype(cls, spec, value, spec_dtype=None):
ret_dtype = "ascii"
else:
dtype_func, warning_msg = cls.__resolve_numeric_dtype(value.dtype, spec_dtype_type)
ret = np.asarray(value).astype(dtype_func)
if value.dtype == dtype_func:
ret = value
else:
ret = value.astype(dtype_func)
ret_dtype = ret.dtype.type
elif isinstance(value, (tuple, list)):
if len(value) == 0:
Expand Down Expand Up @@ -917,7 +920,6 @@ def __add_groups(self, builder, groups, container, build_manager, source, export
% (repr(spec.name),
spec.def_key(), repr(spec.data_type_def),
spec.inc_key(), repr(spec.data_type_inc)))
attr_name = self.get_attribute(spec)
attr_value = self.get_attr_value(spec, container, build_manager)
if attr_value is not None:
self.__add_containers(builder, spec, attr_value, build_manager, source, container, export)
Expand Down
2 changes: 2 additions & 0 deletions src/hdmf/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def available_namespaces():

from . import table # noqa: F401,E402
from . import sparse # noqa: F401,E402
from . import multi # noqa: F401,E402

from .. import Data, Container
__TYPE_MAP.register_container_type(CORE_NAMESPACE, 'Container', Container)
Expand All @@ -120,6 +121,7 @@ def available_namespaces():
DynamicTableRegion = __TYPE_MAP.get_container_cls(CORE_NAMESPACE, 'DynamicTableRegion')
VocabData = __TYPE_MAP.get_container_cls(CORE_NAMESPACE, 'VocabData')
CSRMatrix = __TYPE_MAP.get_container_cls(CORE_NAMESPACE, 'CSRMatrix')
SimpleMultiContainer = __TYPE_MAP.get_container_cls(CORE_NAMESPACE, 'SimpleMultiContainer')


@docval({'name': 'extensions', 'type': (str, TypeMap, list),
Expand Down
1 change: 1 addition & 0 deletions src/hdmf/common/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from . import table # noqa: F401
from . import multi # noqa: F401
21 changes: 21 additions & 0 deletions src/hdmf/common/io/multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from ...build import ObjectMapper
from ..multi import SimpleMultiContainer
from .. import register_map
from ...container import Container, Data


@register_map(SimpleMultiContainer)
class SimpleMultiContainerMap(ObjectMapper):
    """Object mapper for SimpleMultiContainer.

    Splits the single ``containers`` attribute into group builders (for
    Container members) and dataset builders (for Data members) on write,
    and recombines them on read.
    """

    @ObjectMapper.object_attr('containers')
    def containers_attr(self, container, manager):
        # Only the Container members — these become sub-groups.
        return [obj for obj in container.containers.values() if isinstance(obj, Container)]

    @ObjectMapper.constructor_arg('containers')
    def containers_carg(self, builder, manager):
        # Reconstruct Data members from datasets first, then Container
        # members from groups, preserving the original concatenation order.
        data_members = [manager.construct(sub_builder) for sub_builder in builder.datasets.values()
                        if manager.is_sub_data_type(sub_builder, 'Data')]
        container_members = [manager.construct(sub_builder) for sub_builder in builder.groups.values()
                             if manager.is_sub_data_type(sub_builder, 'Container')]
        return data_members + container_members

    @ObjectMapper.object_attr('datas')
    def datas_attr(self, container, manager):
        # Only the Data members — these become sub-datasets.
        return [obj for obj in container.containers.values() if isinstance(obj, Data)]
23 changes: 23 additions & 0 deletions src/hdmf/common/multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from ..container import Container, Data, MultiContainerInterface
from ..utils import docval, call_docval_func, popargs

from . import register_class


@register_class('SimpleMultiContainer')
class SimpleMultiContainer(MultiContainerInterface):
    """A generic holder that groups arbitrary Container and Data objects together."""

    # MultiContainerInterface configuration: a single 'containers' collection
    # with generated add_container/get_container accessors.
    __clsconf__ = {
        'attr': 'containers',
        'type': (Container, Data),
        'add': 'add_container',
        'get': 'get_container',
    }

    @docval({'name': 'name', 'type': str, 'doc': 'the name of this container'},
            {'name': 'containers', 'type': (list, tuple), 'default': None,
             'doc': 'the Container or Data objects in this file'})
    def __init__(self, **kwargs):
        # Pop 'containers' before delegating so only the remaining
        # arguments (i.e. 'name') are forwarded to the superclass.
        members = popargs('containers', kwargs)
        call_docval_func(super().__init__, kwargs)
        self.containers = members
79 changes: 57 additions & 22 deletions src/hdmf/common/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def __getitem__(self, key):
ret = self.get(key)
if ret is None:
raise KeyError(key)
return self.get(key)
return ret

def get(self, key, default=None, df=True, **kwargs): # noqa: C901
"""
Expand Down Expand Up @@ -776,28 +776,17 @@ def get(self, key, default=None, df=True, **kwargs): # noqa: C901
ret['id'] = self.id.data[arg]
for name in self.colnames:
col = self.__df_cols[self.__colids[name]]
if isinstance(col.data, (Dataset, np.ndarray)) and col.data.ndim > 1:
ret[name] = col.get(arg, df=df, **kwargs)
else:
currdata = col.get(arg, df=df, **kwargs)
ret[name] = currdata
ret[name] = col.get(arg, df=df, **kwargs)
# index by a list of ints, return multiple rows
elif isinstance(arg, (tuple, list, np.ndarray)):
elif isinstance(arg, (list, np.ndarray)):
if isinstance(arg, np.ndarray):
if len(arg.shape) != 1:
raise ValueError("cannot index DynamicTable with multiple dimensions")
ret = OrderedDict()
ret['id'] = (self.id.data[arg]
if isinstance(self.id.data, np.ndarray)
else [self.id.data[i] for i in arg])
ret['id'] = self.id[arg]
for name in self.colnames:
col = self.__df_cols[self.__colids[name]]
if isinstance(col.data, (Dataset, np.ndarray)) and col.data.ndim > 1:
ret[name] = [x for x in col.get(arg, df=df, **kwargs)]
elif isinstance(col.data, (list, np.ndarray)):
ret[name] = col.get(arg, df=df, **kwargs)
else:
ret[name] = [col.get(arg, df=df, **kwargs) for i in arg]
ret[name] = col.get(arg, df=df, **kwargs)
else:
raise KeyError("Key type not supported by DynamicTable %s" % str(type(arg)))
except ValueError as ve:
Expand All @@ -815,7 +804,6 @@ def get(self, key, default=None, df=True, **kwargs): # noqa: C901
raise IndexError(msg)
else: # pragma: no cover
raise ie

if df:
# reformat objects to fit into a pandas DataFrame
id_index = ret.pop('id')
Expand Down Expand Up @@ -852,7 +840,6 @@ def get(self, key, default=None, df=True, **kwargs): # noqa: C901
retdf[newcolname] = ret[k][col].values
else:
retdf[k] = ret[k]

ret = pd.DataFrame(retdf, index=pd.Index(name=self.id.name, data=id_index))
else:
ret = list(ret.values())
Expand Down Expand Up @@ -1030,7 +1017,7 @@ def table(self, val):
def __getitem__(self, arg):
return self.get(arg)

def get(self, arg, index=False, **kwargs):
def get(self, arg, index=False, df=True, **kwargs):
"""
Subset the DynamicTableRegion
Expand All @@ -1046,16 +1033,64 @@ def get(self, arg, index=False, **kwargs):
arg1 = arg[0]
arg2 = arg[1]
return self.table[self.data[arg1], arg2]
elif isinstance(arg, slice) or np.issubdtype(type(arg), np.integer):
if np.issubdtype(type(arg), np.integer) and arg >= len(self.data):
elif np.issubdtype(type(arg), np.integer):
if arg >= len(self.data):
raise IndexError('index {} out of bounds for data of length {}'.format(arg, len(self.data)))
ret = self.data[arg]
if not index:
ret = self.table.get(ret, **kwargs)
ret = self.table.get(ret, df=df, **kwargs)
return ret
elif isinstance(arg, (list, slice, np.ndarray)):
idx = arg

# get the data at the specified indices
if isinstance(self.data, (tuple, list)) and isinstance(idx, list):
ret = [self.data[i] for i in idx]
else:
ret = self.data[idx]

# dereference them if necessary
if not index:
# These lines are needed because indexing Dataset with a list/ndarray
# of ints requires the list to be sorted.
#
# First get the unique elements, retrieve them from the table, and then
# reorder the result according to the original index that the user passed in.
#
# When not returning a DataFrame, we need to recursively sort the subelements
# of the list we are returning. This is carried out by the recursive method _index_lol
uniq = np.unique(ret)
lut = {val: i for i, val in enumerate(uniq)}
values = self.table.get(uniq, df=df, **kwargs)
if df:
ret = values.iloc[[lut[i] for i in ret]]
else:
ret = self._index_lol(values, ret, lut)

return ret
else:
raise ValueError("unrecognized argument: '%s'" % arg)

def _index_lol(self, result, index, lut):
"""
This is a helper function for indexing a list of lists/ndarrays. When not returning a
DataFrame, indexing a DynamicTable will return a list of lists and ndarrays. To sort
the result of a DynamicTable index according to the order of the indices passed in by the
user, we have to recursively sort the sub-lists/sub-ndarrays.
"""
ret = list()
for col in result:
if isinstance(col, list):
if isinstance(col[0], list):
ret.append(self._index_lol(col, index, lut))
else:
ret.append([col[lut[i]] for i in index])
elif isinstance(col, np.ndarray):
ret.append(np.array([col[lut[i]] for i in index], dtype=col.dtype))
else:
raise ValueError('unrecognized column type: %s. Expected list or np.ndarray' % type(col))
return ret

def to_dataframe(self, **kwargs):
"""
Convert the whole DynamicTableRegion to a pandas dataframe.
Expand Down
4 changes: 4 additions & 0 deletions src/hdmf/container.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import h5py
import numpy as np
from abc import ABCMeta, abstractmethod
from uuid import uuid4
Expand Down Expand Up @@ -510,6 +511,9 @@ def __getitem__(self, args):
def get(self, args):
    """Return the element(s) of ``self.data`` selected by ``args``."""
    fancy_index = isinstance(args, (tuple, list, np.ndarray))
    if fancy_index and isinstance(self.data, (tuple, list)):
        # plain Python sequences do not support fancy indexing — do it manually
        return [self.data[i] for i in args]
    if isinstance(args, np.ndarray) and isinstance(self.data, h5py.Dataset):
        # h5py 2.9 cannot index a Dataset with an ndarray; convert to a list
        args = args.tolist()
    return self.data[args]

def append(self, arg):
Expand Down
15 changes: 15 additions & 0 deletions src/hdmf/spec/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,21 @@ def get_hierarchy(self, **kwargs):
raise KeyError("'%s' not a namespace" % namespace)
return spec_ns.get_hierarchy(data_type)

@docval({'name': 'namespace', 'type': str, 'doc': 'the name of the namespace containing the data_type'},
        {'name': 'data_type', 'type': (str, type), 'doc': 'the data_type to check'},
        {'name': 'parent_data_type', 'type': (str, type), 'doc': 'the potential parent data_type'},
        returns="True if *data_type* is a sub `data_type` of *parent_data_type*, False otherwise", rtype=bool)
def is_sub_data_type(self, **kwargs):
    """Check whether *data_type* inherits from *parent_data_type* in *namespace*.

    Raises KeyError if *namespace* is not a registered namespace.
    """
    namespace, data_type, parent_data_type = getargs('namespace', 'data_type', 'parent_data_type', kwargs)
    spec_ns = self.__namespaces.get(namespace)
    if spec_ns is None:
        raise KeyError("'%s' not a namespace" % namespace)
    # a type's hierarchy includes all of its ancestors
    return parent_data_type in spec_ns.get_hierarchy(data_type)

@docval(rtype=tuple)
def get_sources(self, **kwargs):
'''
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/common/test_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from hdmf.container import Container, Data
from hdmf.common import SimpleMultiContainer
from hdmf.testing import TestCase, H5RoundTripMixin


class SimpleMultiContainerRoundTrip(H5RoundTripMixin, TestCase):
    """Round-trip a SimpleMultiContainer holding both Container and Data members."""

    def setUpContainer(self):
        members = [
            Container('container1'),
            Container('container2'),
            Data('data1', [0, 1, 2, 3, 4]),
            Data('data2', [0.0, 1.0, 2.0, 3.0, 4.0]),
        ]
        return SimpleMultiContainer('multi', members)

0 comments on commit f4b8232

Please sign in to comment.