diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 0e1c68466..e6b898d87 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -1,4 +1,4 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable, MutableSequence, Sequence import numpy as np from keras.saving import ( @@ -25,17 +25,16 @@ ToArray, Transform, ) - from .transforms.filter_transform import Predicate @serializable(package="bayesflow.adapters") -class Adapter: +class Adapter(MutableSequence[Transform]): def __init__(self, transforms: Sequence[Transform] | None = None): if transforms is None: transforms = [] - self.transforms = transforms + self.transforms = list(transforms) @staticmethod def create_default(inference_variables: Sequence[str]) -> "Adapter": @@ -76,12 +75,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) -> return self.forward(data, **kwargs) def __repr__(self): - return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" + result = "" + for i, transform in enumerate(self): + result += f"{i}: {transform!r}" + if i != len(self) - 1: + result += " -> " + + return f"Adapter([{result}])" + + # list methods + + def append(self, value: Transform) -> "Adapter": + self.transforms.append(value) + return self + + def __delitem__(self, key: int | slice): + del self.transforms[key] + + def extend(self, values: Sequence[Transform]) -> "Adapter": + if isinstance(values, Adapter): + values = values.transforms + + self.transforms.extend(values) + + return self + + def __getitem__(self, item: int | slice) -> "Adapter": + if isinstance(item, int): + return self.transforms[item] + + return Adapter(self.transforms[item]) + + def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter": + if isinstance(value, Adapter): + value = value.transforms + + if isinstance(value, Sequence): + # convenience: Adapters are always flat + self.transforms = self.transforms[:index] + list(value) + self.transforms[index:] + else: + self.transforms.insert(index, value) + + return self + + def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter": + if isinstance(value, Adapter): + value = value.transforms + + if isinstance(key, int) and isinstance(value, Sequence): + if key < 0: + key += len(self.transforms) + + key = slice(key, key + 1) + + self.transforms[key] = value - def add_transform(self, transform: Transform): - self.transforms.append(transform) return self + def __len__(self): + return len(self.transforms) + + # adapter methods + + add_transform = append + def apply( self, *,