From 25b73d3100d26527a1ec1c269ad0c12f1e5fe9fe Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 20 May 2025 07:36:55 +0000 Subject: [PATCH 1/2] Add squeeze transform Very basic transform, just the inverse of expand_dims --- bayesflow/adapters/adapter.py | 18 ++++++++++ bayesflow/adapters/transforms/__init__.py | 1 + bayesflow/adapters/transforms/squeeze.py | 42 +++++++++++++++++++++++ tests/test_adapters/conftest.py | 1 + 4 files changed, 62 insertions(+) create mode 100644 bayesflow/adapters/transforms/squeeze.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 6c4802a44..be2ae998a 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -22,6 +22,7 @@ OneHot, Rename, SerializableCustomTransform, + Squeeze, Sqrt, Standardize, ToArray, @@ -780,6 +781,23 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq return self + def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple): + """Append a :py:class:`~transforms.Squeeze` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to squeeze. + axis : int or tuple + The axis to squeeze. + """ + if isinstance(keys, str): + keys = [keys] + + transform = MapTransform({key: Squeeze(axis=axis) for key in keys}) + self.transforms.append(transform) + return self + def sqrt(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.Sqrt` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index 2651c65a7..4b24e0bc2 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -19,6 +19,7 @@ from .serializable_custom_transform import SerializableCustomTransform from .shift import Shift from .split import Split +from .squeeze import Squeeze from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray diff --git a/bayesflow/adapters/transforms/squeeze.py b/bayesflow/adapters/transforms/squeeze.py new file mode 100644 index 000000000..edf1bc3a0 --- /dev/null +++ b/bayesflow/adapters/transforms/squeeze.py @@ -0,0 +1,42 @@ +import numpy as np + +from bayesflow.utils.serialization import serializable, serialize + +from .elementwise_transform import ElementwiseTransform + + +@serializable("bayesflow.adapters") +class Squeeze(ElementwiseTransform): + """ + Squeeze dimensions of an array. + + Parameters + ---------- + axis : int or tuple + The axis to squeeze. + + Examples + -------- + shape (3, 1) array: + + >>> a = np.array([[1], [2], [3]]) + + >>> sq = bf.adapters.transforms.Squeeze(axis=1) + >>> sq.forward(a).shape + (3,) + + It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform. + """ + + def __init__(self, *, axis: int | tuple): + super().__init__() + self.axis = axis + + def get_config(self) -> dict: + return serialize({"axis": self.axis}) + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.squeeze(data, axis=self.axis) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.expand_dims(data, axis=self.axis) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index e09a74c33..3193309ae 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -21,6 +21,7 @@ def serializable_fn(x): .concatenate(["x1", "x2"], into="x") .concatenate(["y1", "y2"], into="y") .expand_dims(["z1"], axis=2) + .squeeze("z1", axis=2) .log("p1") .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log") From 98a6bcadecb19b34d09039b12a2915d0664a89df Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 20 May 2025 19:46:44 +0000 Subject: [PATCH 2/2] squeeze: adapt example, add comment for changing batch dims --- bayesflow/adapters/adapter.py | 3 ++- bayesflow/adapters/transforms/squeeze.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index be2ae998a..b547083a8 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -789,7 +789,8 @@ def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple): keys : str or Sequence of str The names of the variables to squeeze. axis : int or tuple - The axis to squeeze. + The axis to squeeze. As the number of batch dimensions might change, we advise using negative + numbers (i.e., indexing from the end instead of the start). """ if isinstance(keys, str): keys = [keys] diff --git a/bayesflow/adapters/transforms/squeeze.py b/bayesflow/adapters/transforms/squeeze.py index edf1bc3a0..df9a10a80 100644 --- a/bayesflow/adapters/transforms/squeeze.py +++ b/bayesflow/adapters/transforms/squeeze.py @@ -13,7 +13,8 @@ class Squeeze(ElementwiseTransform): Parameters ---------- axis : int or tuple - The axis to squeeze. + The axis to squeeze. As the number of batch dimensions might change, we advise using negative + numbers (i.e., indexing from the end instead of the start). Examples -------- @@ -21,7 +22,7 @@ class Squeeze(ElementwiseTransform): >>> a = np.array([[1], [2], [3]]) - >>> sq = bf.adapters.transforms.Squeeze(axis=1) + >>> sq = bf.adapters.transforms.Squeeze(axis=-1) >>> sq.forward(a).shape (3,)