Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep attrs of get_bitinformation #158

Merged
merged 15 commits into from Nov 23, 2022
Merged
1 change: 1 addition & 0 deletions CHANGELOG.rst
Expand Up @@ -7,6 +7,7 @@ CHANGELOG

* Fix julia package installations for PyPi and enable installation via pip and conda (:issue:`18`, :pr:`132`, :pr:`131`) `Filipe Fernandes`_, `Mark Kittisopikul`_.
* Fix compression example for zarr-files (:issue:`119`, :pr:`121`) `Hauke Schulz`_.
* Keep ``attrs`` from input :py:func:`xbitinfo.xbitinfo.get_bitinformation`. (:issue:`154`, :pr:`158`) `Aaron Spring`_.

0.0.2 (2022-07-11)
------------------
Expand Down
7 changes: 7 additions & 0 deletions tests/test_get_bitinformation.py
Expand Up @@ -215,3 +215,10 @@ def test_get_bitinformation_different_dtypes(rasm, implementation):
def test_get_bitinformation_dim_list(rasm, implementation):
bi = xb.get_bitinformation(rasm, dim=["x", "y"], implementation=implementation)
assert (bi.dim == ["x", "y"]).all()


def test_get_bitinformation_keep_attrs(rasm):
bi = xb.get_bitinformation(rasm, dim=["x", "y"]).Tair
assert "bitinfo_units" in bi.attrs
assert bi.attrs["bitinfo_units"] == 1
assert set(rasm.Tair.attrs.keys()).issubset(bi.attrs.keys())
12 changes: 9 additions & 3 deletions xbitinfo/xbitinfo.py
Expand Up @@ -68,7 +68,10 @@ def dict_to_dataset(info_per_bit):
dims=[dim_name],
coords={dim_name: get_bit_coords(dtype_size), "dim": dim},
name=v,
attrs={"long_name": f"{v} bitwise information", "units": "1"},
attrs={
"bitinfo_long_name": f"{v} bitwise information",
"bitinfo_units": 1,
},
).astype("float64")
# add metadata
dsb.attrs = {
Expand All @@ -90,7 +93,7 @@ def dict_to_dataset(info_per_bit):
return dsb


def get_bitinformation(
def get_bitinformation( # noqa: C901
ds,
dim=None,
axis=None,
Expand Down Expand Up @@ -236,7 +239,10 @@ def get_bitinformation(
with open(label + ".json", "w") as f:
logging.debug(f"Save bitinformation to {label + '.json'}")
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
return dict_to_dataset(info_per_bit)
info_per_bit = dict_to_dataset(info_per_bit)
for var in info_per_bit.data_vars: # keep attrs from input
info_per_bit[var].attrs.update(ds[var].attrs)
return info_per_bit


def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}):
Expand Down