Skip to content
71 changes: 64 additions & 7 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Callable, Sequence
from collections.abc import Callable, MutableSequence, Sequence

import numpy as np
from keras.saving import (
Expand All @@ -25,17 +25,16 @@
ToArray,
Transform,
)

from .transforms.filter_transform import Predicate


@serializable(package="bayesflow.adapters")
class Adapter:
class Adapter(MutableSequence[Transform]):
def __init__(self, transforms: Sequence[Transform] | None = None):
if transforms is None:
transforms = []

self.transforms = transforms
self.transforms = list(transforms)

@staticmethod
def create_default(inference_variables: Sequence[str]) -> "Adapter":
Expand Down Expand Up @@ -76,12 +75,70 @@ def __call__(self, data: dict[str, any], *, inverse: bool = False, **kwargs) ->
return self.forward(data, **kwargs)

def __repr__(self):
return f"Adapter([{' -> '.join(map(repr, self.transforms))}])"
result = ""
for i, transform in enumerate(self):
result += f"{i}: {transform!r}"
if i != len(self) - 1:
result += " -> "

return f"Adapter([{result}])"

# list methods

def append(self, value: Transform) -> "Adapter":
self.transforms.append(value)
return self

def __delitem__(self, key: int | slice):
del self.transforms[key]

def extend(self, values: Sequence[Transform]) -> "Adapter":
if isinstance(values, Adapter):
values = values.transforms

self.transforms.extend(values)

return self

def __getitem__(self, item: int | slice) -> "Adapter":
if isinstance(item, int):
return self.transforms[item]

return Adapter(self.transforms[item])

def insert(self, index: int, value: Transform | Sequence[Transform]) -> "Adapter":
if isinstance(value, Adapter):
value = value.transforms

if isinstance(value, Sequence):
# convenience: Adapters are always flat
self.transforms = self.transforms[:index] + list(value) + self.transforms[index:]
else:
self.transforms.insert(index, value)

return self

def __setitem__(self, key: int | slice, value: Transform | Sequence[Transform]) -> "Adapter":
if isinstance(value, Adapter):
value = value.transforms

if isinstance(key, int) and isinstance(value, Sequence):
if key < 0:
key += len(self.transforms)

key = slice(key, key + 1)

self.transforms[key] = value

def add_transform(self, transform: Transform):
self.transforms.append(transform)
return self

def __len__(self):
return len(self.transforms)

# adapter methods

add_transform = append

def apply(
self,
*,
Expand Down
Loading