Skip to content

Commit

Permalink
implement add_usertypes/get_usertypes, fix tests, remove stale code a…
Browse files Browse the repository at this point in the history
…nd comments
  • Loading branch information
kmuehlbauer committed Jan 26, 2024
1 parent 6dc11d3 commit ee46d4c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
65 changes: 40 additions & 25 deletions h5netcdf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,16 +590,22 @@ def _check_dtype(self, dtype):
f" file {dtype._root._h5file.filename}"
)
# check if committed type can be accessed in current group hierarchy
dname = dtype.name.split("/")[-1]
if (
(user_type := self._all_usertypes.get(dname)) is None
) or self._root._h5file[user_type].name != h5type.name:
user_type = self._get_usertype(dtype)
if user_type is None:
msg = (
f"Given dtype {dtype.name!r} is not accessible in current group"
f" {self._h5group.name!r} or any parent group. Instead it's defined at"
f" {h5type.name!r}. Please create it in the current or any parent group."
)
raise TypeError(msg)
# this checks for committed types which are overridden by re-definitions
elif (actual := user_type._h5ds.name) != h5type.name:
msg = (
f"Given dtype {dtype.name!r} is defined at {h5type.name!r}."
f" Another dtype with same name is defined at {actual!r} and"
f" would override it."
)
raise TypeError(msg)
elif np.dtype(dtype).kind == "c":
itemsize = np.dtype(dtype).itemsize
try:
Expand All @@ -610,13 +616,11 @@ def _check_dtype(self, dtype):
) from e
dname = f"_PFNC_{width}_COMPLEX_TYPE"
# todo check compound type for existing complex types
# which may be used her
# which may be used here
# if dname is not available in current group-path
# create and commit type in current group
if dname not in self._all_cmptypes:
dtype = self.create_cmptype(dtype, dname).dtype
# get committed type from file
# return self._all_cmptypes[dname]._h5ds

return dtype

Expand Down Expand Up @@ -726,23 +730,20 @@ def __init__(self, parent, name):
self._enumtypes = _LazyObjectLookup(self, self._enumtype_cls)
self._vltypes = _LazyObjectLookup(self, self._vltype_cls)
self._cmptypes = _LazyObjectLookup(self, self._cmptype_cls)
self._usertypes = dict()

# this map keeps track of all dimensions
if parent is self:
self._all_dimensions = ChainMap(self._dimensions)
self._all_enumtypes = ChainMap(self._enumtypes)
self._all_vltypes = ChainMap(self._vltypes)
self._all_cmptypes = ChainMap(self._cmptypes)
self._all_usertypes = ChainMap(self._usertypes)

else:
self._all_dimensions = parent._all_dimensions.new_child(self._dimensions)
self._all_h5groups = parent._all_h5groups.new_child(self._h5group)
self._all_enumtypes = parent._all_enumtypes.new_child(self._enumtypes)
self._all_vltypes = parent._all_vltypes.new_child(self._vltypes)
self._all_cmptypes = parent._all_cmptypes.new_child(self._cmptypes)
self._all_usertypes = parent._all_usertypes.new_child(self._usertypes)

self._variables = _LazyObjectLookup(self, self._variable_cls)
self._groups = _LazyObjectLookup(self, self._group_cls)
Expand All @@ -757,13 +758,8 @@ def __init__(self, parent, name):
# instance
self._groups.add(k)
elif isinstance(v, self._root._h5py.Datatype):
if self._root._h5py.check_enum_dtype(v.dtype):
self._enumtypes.add(k)
elif self._root._h5py.check_vlen_dtype(v.dtype):
self._vltypes.add(k)
elif v.dtype.names is not None or "complex" in v.dtype.name:
self._cmptypes.add(k)
self._usertypes[k] = v.name
# add usertypes (enum, vlen, compound)
self._add_usertype(v)
else:
if v.attrs.get("CLASS") == b"DIMENSION_SCALE":
# add dimension and retrieve size
Expand Down Expand Up @@ -1154,6 +1150,33 @@ def groups(self):
def variables(self):
return Frozen(self._variables)

def _add_usertype(self, usertype):
    """Register a committed user type in the matching per-group lookup.

    The type is classified from its numpy dtype:

    * ``"enum"`` key in the dtype metadata -> ``self._enumtypes``
    * ``"vlen"`` key in the dtype metadata -> ``self._vltypes``
    * structured dtype (``dtype.names`` is set) or a dtype whose name
      contains ``"complex"`` -> ``self._cmptypes``

    Parameters
    ----------
    usertype
        Object exposing ``name`` (a path, only the last ``/``-separated
        component is used as key) and ``dtype`` (a numpy dtype).

    Raises
    ------
    ValueError
        If the dtype matches none of the supported user type kinds.
    """
    # committed types are keyed by their bare name, not by their full path
    name = usertype.name.split("/")[-1]
    dtype = usertype.dtype
    # ``dtype.metadata`` is None for plain dtypes, normalize to a mapping
    metadata = dtype.metadata or {}
    if "enum" in metadata:
        self._enumtypes.add(name)
    elif "vlen" in metadata:
        self._vltypes.add(name)
    elif dtype.names is not None or "complex" in dtype.name:
        self._cmptypes.add(name)
    else:
        # ``!r`` has to live inside the braces to act as a conversion
        raise ValueError(f"Undefined user type {name!r}.")

def _get_usertype(self, usertype):
    """Look up a user type by name in the current group hierarchy.

    The lookup table is selected from the numpy dtype of ``usertype``:

    * ``"enum"`` key in the dtype metadata -> ``self._all_enumtypes``
    * ``"vlen"`` key in the dtype metadata -> ``self._all_vltypes``
    * structured dtype (``dtype.names`` is set) or a dtype whose name
      contains ``"complex"`` -> ``self._all_cmptypes``

    Parameters
    ----------
    usertype
        Object exposing ``name`` (used as lookup key as-is) and ``dtype``
        (a numpy dtype).

    Returns
    -------
    The entry stored under ``usertype.name`` in the selected lookup, or
    ``None`` if there is no entry with that name.

    Raises
    ------
    ValueError
        If the dtype matches none of the supported user type kinds.
    """
    dtype = usertype.dtype
    # ``dtype.metadata`` is None for plain dtypes, normalize to a mapping
    metadata = dtype.metadata or {}
    if "enum" in metadata:
        return self._all_enumtypes.get(usertype.name)
    elif "vlen" in metadata:
        return self._all_vltypes.get(usertype.name)
    elif dtype.names is not None or "complex" in dtype.name:
        return self._all_cmptypes.get(usertype.name)
    else:
        # ``!r`` has to live inside the braces to act as a conversion
        raise ValueError(f"Undefined user type {dtype!r}.")

@property
def enumtypes(self):
return Frozen(self._enumtypes)
Expand All @@ -1166,10 +1189,6 @@ def vltypes(self):
def cmptypes(self):
return Frozen(self._cmptypes)

@property
def usertypes(self):
return Frozen(self._usertypes)

@property
def dims(self):
return Frozen(self._dimensions)
Expand Down Expand Up @@ -1233,9 +1252,7 @@ def create_enumtype(self, datatype, datatype_name, enum_dict):
self._h5group[datatype_name] = et
# create enumtype class instance
enumtype = self._enumtype_cls(self, datatype_name)
# enumtype = self._usertype_cls(self, datatype_name)
self._enumtypes[datatype_name] = enumtype
self._usertypes[datatype_name] = enumtype._h5ds.name
return enumtype

def create_vltype(self, datatype, datatype_name):
Expand All @@ -1252,7 +1269,6 @@ def create_vltype(self, datatype, datatype_name):
# create vltype class instance
vltype = self._vltype_cls(self, datatype_name)
self._vltypes[datatype_name] = vltype
self._usertypes[datatype_name] = vltype._h5ds.name
return vltype

def create_cmptype(self, datatype, datatype_name):
Expand All @@ -1268,7 +1284,6 @@ def create_cmptype(self, datatype, datatype_name):
# create compound class instance
cmptype = self._cmptype_cls(self, datatype_name)
self._cmptypes[datatype_name] = cmptype
self._usertypes[datatype_name] = cmptype._h5ds.name
return cmptype


Expand Down
18 changes: 17 additions & 1 deletion h5netcdf/tests/test_h5netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2234,6 +2234,7 @@ def test_user_type_errors_new_api(tmp_local_or_remote_netcdf):
ds.create_enumtype(np.uint8, "enum_t", enum_dict2)

enum_type2 = g.create_enumtype(np.uint8, "enum_t2", enum_dict2)
g.create_enumtype(np.uint8, "enum_t", enum_dict2)
with pytest.raises(TypeError, match="Please provide h5netcdf user type"):
ds.create_variable(
"enum_var1",
Expand All @@ -2255,6 +2256,13 @@ def test_user_type_errors_new_api(tmp_local_or_remote_netcdf):
dtype=enum_type2,
fillvalue=enum_dict2["missing"],
)
with pytest.raises(TypeError, match="Another dtype with same name"):
g.create_variable(
"enum_var4",
("enum_dim",),
dtype=enum_type,
fillvalue=enum_dict2["missing"],
)


def test_user_type_errors_legacyapi(tmp_local_or_remote_netcdf):
Expand All @@ -2274,7 +2282,7 @@ def test_user_type_errors_legacyapi(tmp_local_or_remote_netcdf):
ds.createEnumType(np.uint8, "enum_t", enum_dict1)

enum_type2 = g.createEnumType(np.uint8, "enum_t2", enum_dict2)

g.create_enumtype(np.uint8, "enum_t", enum_dict2)
with pytest.raises(TypeError, match="Please provide h5netcdf user type"):
ds.createVariable(
"enum_var1",
Expand All @@ -2296,6 +2304,13 @@ def test_user_type_errors_legacyapi(tmp_local_or_remote_netcdf):
("enum_dim",),
fill_value=enum_dict2["missing"],
)
with pytest.raises(TypeError, match="Another dtype with same name"):
g.createVariable(
"enum_var4",
enum_type,
("enum_dim",),
fill_value=enum_dict2["missing"],
)


def test_enum_type_errors_new_api(tmp_local_or_remote_netcdf):
Expand Down Expand Up @@ -2567,6 +2582,7 @@ def test_compoundtype_creation(tmp_local_or_remote_netcdf, netcdf_write_module):
pytest.skip("does not work for netCDF4")
with netcdf_write_module.Dataset(tmp_local_or_remote_netcdf, "w") as ds:
ds.createDimension("x", 5)
ds.createGroup("test")
compound_t = ds.createCompoundType(compound, "cmp_t")
var = ds.createVariable("data", compound_t, ("x",))
var[:] = cmp_array
Expand Down

0 comments on commit ee46d4c

Please sign in to comment.