diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 6c4802a44..b547083a8 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -22,6 +22,7 @@ OneHot, Rename, SerializableCustomTransform, + Squeeze, Sqrt, Standardize, ToArray, @@ -780,6 +781,24 @@ def split(self, key: str, *, into: Sequence[str], indices_or_sections: int | Seq return self + def squeeze(self, keys: str | Sequence[str], *, axis: int | tuple): + """Append a :py:class:`~transforms.Squeeze` transform to the adapter. + + Parameters + ---------- + keys : str or Sequence of str + The names of the variables to squeeze. + axis : int or tuple + The axis to squeeze. As the number of batch dimensions might change, we advise using negative + numbers (i.e., indexing from the end instead of the start). + """ + if isinstance(keys, str): + keys = [keys] + + transform = MapTransform({key: Squeeze(axis=axis) for key in keys}) + self.transforms.append(transform) + return self + def sqrt(self, keys: str | Sequence[str]): """Append an :py:class:`~transforms.Sqrt` transform to the adapter. diff --git a/bayesflow/adapters/transforms/__init__.py b/bayesflow/adapters/transforms/__init__.py index 2651c65a7..4b24e0bc2 100644 --- a/bayesflow/adapters/transforms/__init__.py +++ b/bayesflow/adapters/transforms/__init__.py @@ -19,6 +19,7 @@ from .serializable_custom_transform import SerializableCustomTransform from .shift import Shift from .split import Split +from .squeeze import Squeeze from .sqrt import Sqrt from .standardize import Standardize from .to_array import ToArray diff --git a/bayesflow/adapters/transforms/squeeze.py b/bayesflow/adapters/transforms/squeeze.py new file mode 100644 index 000000000..df9a10a80 --- /dev/null +++ b/bayesflow/adapters/transforms/squeeze.py @@ -0,0 +1,43 @@ +import numpy as np + +from bayesflow.utils.serialization import serializable, serialize + +from .elementwise_transform import ElementwiseTransform + + +@serializable("bayesflow.adapters") +class Squeeze(ElementwiseTransform): + """ + Squeeze dimensions of an array. + + Parameters + ---------- + axis : int or tuple + The axis to squeeze. As the number of batch dimensions might change, we advise using negative + numbers (i.e., indexing from the end instead of the start). + + Examples + -------- + shape (3, 1) array: + + >>> a = np.array([[1], [2], [3]]) + + >>> sq = bf.adapters.transforms.Squeeze(axis=-1) + >>> sq.forward(a).shape + (3,) + + It is recommended to precede this transform with a :class:`~bayesflow.adapters.transforms.ToArray` transform. + """ + + def __init__(self, *, axis: int | tuple): + super().__init__() + self.axis = axis + + def get_config(self) -> dict: + return serialize({"axis": self.axis}) + + def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.squeeze(data, axis=self.axis) + + def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + return np.expand_dims(data, axis=self.axis) diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index e09a74c33..3193309ae 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -21,6 +21,7 @@ def serializable_fn(x): .concatenate(["x1", "x2"], into="x") .concatenate(["y1", "y2"], into="y") .expand_dims(["z1"], axis=2) + .squeeze("z1", axis=2) .log("p1") .constrain("p2", lower=0) .apply(include="p2", forward="exp", inverse="log")