Skip to content

Commit

Permalink
Merge pull request #126 from observingClouds/py_bitinfo
Browse files Browse the repository at this point in the history
Pythonic bitinformation
  • Loading branch information
observingClouds committed Oct 21, 2022
2 parents 61a03c8 + b70268e commit ebf417e
Show file tree
Hide file tree
Showing 4 changed files with 254 additions and 71 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
dask
xarray
julia
tqdm
Expand Down
118 changes: 78 additions & 40 deletions tests/test_get_bitinformation.py
Expand Up @@ -63,72 +63,106 @@ def bitinfo_assert_different(bitinfo1, bitinfo2):
assert (bitinfo1 != bitinfo2).any()


def test_get_bitinformation_returns_dataset():
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_returns_dataset(implementation):
"""Test xb.get_bitinformation returns xr.Dataset."""
ds = xr.tutorial.load_dataset("rasm")
assert isinstance(xb.get_bitinformation(ds, axis=0), xr.Dataset)
assert isinstance(
xb.get_bitinformation(ds, implementation=implementation, axis=0), xr.Dataset
)


def test_get_bitinformation_dim():
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_dim(implementation):
"""Test xb.get_bitinformation is sensitive to dim."""
ds = xr.tutorial.load_dataset("rasm")
bitinfo0 = xb.get_bitinformation(ds, axis=0)
bitinfo2 = xb.get_bitinformation(ds, axis=2)
bitinfo0 = xb.get_bitinformation(ds, axis=0, implementation=implementation)
bitinfo2 = xb.get_bitinformation(ds, axis=2, implementation=implementation)
assert_different(bitinfo0, bitinfo2)


def test_get_bitinformation_dim_string_equals_axis_int():
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_dim_string_equals_axis_int(implementation):
"""Test xb.get_bitinformation undestands xarray dimension names the same way as axis as integers."""
ds = xr.tutorial.load_dataset("rasm")
bitinfox = xb.get_bitinformation(ds, dim="x")
bitinfo2 = xb.get_bitinformation(ds, axis=2)
bitinfox = xb.get_bitinformation(ds, dim="x", implementation=implementation)
bitinfo2 = xb.get_bitinformation(ds, axis=2, implementation=implementation)
assert_identical(bitinfox, bitinfo2)


def test_get_bitinformation_masked_value():
def test_get_bitinformation_masked_value(implementation="julia"):
"""Test xb.get_bitinformation is sensitive to masked_value."""
ds = xr.tutorial.load_dataset("rasm")
bitinfo = xb.get_bitinformation(ds, dim="x")
bitinfo_no_mask = xb.get_bitinformation(ds, dim="x", masked_value="nothing")
bitinfo_no_mask_None = xb.get_bitinformation(ds, dim="x", masked_value=None)
bitinfo = xb.get_bitinformation(ds, dim="x", implementation=implementation)
bitinfo_no_mask = xb.get_bitinformation(
ds, dim="x", masked_value="nothing", implementation=implementation
)
bitinfo_no_mask_None = xb.get_bitinformation(
ds, dim="x", masked_value=None, implementation=implementation
)
assert_identical(bitinfo_no_mask, bitinfo_no_mask_None)
assert_different(bitinfo, bitinfo_no_mask)


def test_get_bitinformation_set_zero_insignificant():
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_set_zero_insignificant(implementation):
"""Test xb.get_bitinformation is sensitive to set_zero_insignificant."""
ds = xr.tutorial.load_dataset("air_temperature")
dim = "lon"
bitinfo_szi_False = xb.get_bitinformation(ds, dim=dim, set_zero_insignificant=False)
bitinfo_szi_True = xb.get_bitinformation(ds, dim=dim, set_zero_insignificant=True)
bitinfo = xb.get_bitinformation(ds, dim=dim)
assert_different(bitinfo, bitinfo_szi_False)
assert_identical(bitinfo, bitinfo_szi_True)


def test_get_bitinformation_confidence():
bitinfo = xb.get_bitinformation(ds, dim=dim, implementation=implementation)
bitinfo_szi_False = xb.get_bitinformation(
ds, dim=dim, set_zero_insignificant=False, implementation=implementation
)
try:
bitinfo_szi_True = xb.get_bitinformation(
ds, dim=dim, set_zero_insignificant=True, implementation=implementation
)
assert_identical(bitinfo, bitinfo_szi_True)
except NotImplementedError:
assert implementation == "python"
if implementation == "python":
assert_identical(bitinfo, bitinfo_szi_False)
elif implementation == "julia":
assert_different(bitinfo, bitinfo_szi_False)


@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_confidence(implementation):
"""Test xb.get_bitinformation is sensitive to confidence."""
ds = xr.tutorial.load_dataset("air_temperature")
dim = "lon"
bitinfo_conf99 = xb.get_bitinformation(ds, dim=dim, confidence=0.99)
bitinfo_conf50 = xb.get_bitinformation(ds, dim=dim, confidence=0.5)
bitinfo = xb.get_bitinformation(ds, dim=dim)
assert_different(bitinfo_conf99, bitinfo_conf50)
assert_identical(bitinfo, bitinfo_conf99)


def test_get_bitinformation_label(rasm):
bitinfo = xb.get_bitinformation(ds, dim=dim, implementation=implementation)
try:
bitinfo_conf99 = xb.get_bitinformation(
ds, dim=dim, confidence=0.99, implementation=implementation
)
bitinfo_conf50 = xb.get_bitinformation(
ds, dim=dim, confidence=0.5, implementation=implementation
)
assert_different(bitinfo_conf99, bitinfo_conf50)
assert_identical(bitinfo, bitinfo_conf99)
except AssertionError:
assert implementation == "python"


@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_label(rasm, implementation):
"""Test xb.get_bitinformation serializes when label given."""
ds = rasm
xb.get_bitinformation(ds, dim="x", label="./tmp_testdir/rasm")
xb.get_bitinformation(
ds, dim="x", label="./tmp_testdir/rasm", implementation=implementation
)
assert os.path.exists("./tmp_testdir/rasm.json")
# second call should be faster
xb.get_bitinformation(ds, dim="x", label="./tmp_testdir/rasm")
xb.get_bitinformation(
ds, dim="x", label="./tmp_testdir/rasm", implementation=implementation
)
os.remove("./tmp_testdir/rasm.json")


@pytest.mark.parametrize("implementation", ["julia", "python"])
@pytest.mark.parametrize("dtype", ["float64", "float32", "float16"])
def test_get_bitinformation_dtype(rasm, dtype):
def test_get_bitinformation_dtype(rasm, dtype, implementation):
"""Test xb.get_bitinformation returns correct number of bits depending on dtype."""
ds = rasm.astype(dtype)
v = list(ds.data_vars)[0]
Expand All @@ -138,10 +172,11 @@ def test_get_bitinformation_dtype(rasm, dtype):
)


def test_get_bitinformation_multidim(rasm):
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_multidim(rasm, implementation):
"""Test xb.get_bitinformation runs on all dimensions by default"""
ds = rasm
bi = xb.get_bitinformation(ds)
bi = xb.get_bitinformation(ds, implementation=implementation)
# check length of dimension
assert bi.dims["dim"] == len(ds.dims)
bi_time = bi.sel(dim="time").Tair.values
Expand All @@ -152,28 +187,31 @@ def test_get_bitinformation_multidim(rasm):
assert any(bi_y != bi_x)


def test_get_bitinformation_different_variables_dims(rasm):
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_different_variables_dims(rasm, implementation):
"""Test xb.get_bitinformation runs with variables of different dimensionality"""
ds = rasm
# add variable with different dimensionality
ds["Tair_mean"] = ds.Tair.mean(dim="time")
bi = xb.get_bitinformation(ds)
bi = xb.get_bitinformation(ds, implementation=implementation)
assert all(np.isnan(bi.Tair_mean.sel(dim="time")))
bi_Tair_mean_x = bi.Tair_mean.sel(dim="x")
bi_Tair_x = bi.Tair.sel(dim="x")
assert_different(bi_Tair_mean_x, bi_Tair_x)


def test_get_bitinformation_different_dtypes(rasm):
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_different_dtypes(rasm, implementation):
ds = rasm
ds["Tair32"] = ds.Tair.astype("float32")
ds["Tair16"] = ds.Tair.astype("float16")
bi = xb.get_bitinformation(ds)
bi = xb.get_bitinformation(ds, implementation=implementation)
for bitdim in ["bit16", "bit32", "bit64"]:
assert bitdim in bi.dims
assert bitdim in bi.coords


def test_get_bitinformation_dim_list(rasm):
bi = xb.get_bitinformation(rasm, dim=["x", "y"])
@pytest.mark.parametrize("implementation", ["julia", "python"])
def test_get_bitinformation_dim_list(rasm, implementation):
bi = xb.get_bitinformation(rasm, dim=["x", "y"], implementation=implementation)
assert (bi.dim == ["x", "y"]).all()
69 changes: 69 additions & 0 deletions xbitinfo/_py_bitinfo.py
@@ -0,0 +1,69 @@
import dask.array as da
import numpy as np
import numpy.ma as nm


def bitpaircount_u1(a, b):
assert a.dtype == "u1"
assert b.dtype == "u1"
unpack_a = (
a.flatten()
.map_blocks(
np.unpackbits,
drop_axis=0,
meta=np.array((), dtype=np.uint8),
chunks=(a.size * 8,),
)
.astype("u1")
)
unpack_b = (
b.flatten()
.map_blocks(
np.unpackbits,
drop_axis=0,
meta=np.array((), dtype=np.uint8),
chunks=(b.size * 8,),
)
.astype("u1")
)
index = ((unpack_a << 1) | unpack_b).reshape(-1, 8)

selection = np.array([0, 1, 2, 3], dtype="u1")
sel = np.where((index[..., np.newaxis]) == selection, True, False)
to_return = sel.sum(axis=0).reshape(8, 2, 2)
return to_return


def bitpaircount(a, b):
assert a.dtype.kind == "u"
assert b.dtype.kind == "u"
nbytes = max(a.dtype.itemsize, b.dtype.itemsize)

a, b = np.broadcast_arrays(a, b)

bytewise_counts = []
for i in range(nbytes):
s = (nbytes - 1 - i) * 8
bitc = bitpaircount_u1((a >> s).astype("u1"), (b >> s).astype("u1"))
bytewise_counts.append(bitc)
return np.concatenate(bytewise_counts, axis=0)


def mutual_information(a, b, base=2):
size = np.prod(np.broadcast_shapes(a.shape, b.shape))
counts = bitpaircount(a, b)

p = counts.astype("float") / size
p = da.ma.masked_equal(p, 0)
pr = p.sum(axis=-1)[..., np.newaxis]
ps = p.sum(axis=-2)[..., np.newaxis, :]
mutual_info = (p * np.log(p / (pr * ps))).sum(axis=(-1, -2)) / np.log(base)
return mutual_info


def bitinformation(a, axis=0):
sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape)))
sb = tuple(
slice(1, None) if i == axis else slice(None) for i in range(len(a.shape))
)
return mutual_information(a[sa], a[sb])

0 comments on commit ebf417e

Please sign in to comment.