6 changes: 6 additions & 0 deletions bayesflow/data_adapters/__init__.py
@@ -0,0 +1,6 @@
from . import transforms

from .composite_data_adapter import CompositeDataAdapter
from .concatenate_keys_data_adapter import ConcatenateKeysDataAdapter
from .data_adapter import DataAdapter
from .flow_matching_data_adapter import FlowMatchingDataAdapter
52 changes: 52 additions & 0 deletions bayesflow/data_adapters/composite_data_adapter.py
@@ -0,0 +1,52 @@
from collections.abc import Mapping
from keras.saving import (
    deserialize_keras_object as deserialize,
    register_keras_serializable as serializable,
    serialize_keras_object as serialize,
)
import numpy as np

from .data_adapter import DataAdapter


TRaw = Mapping[str, np.ndarray]
TProcessed = Mapping[str, np.ndarray]


@serializable(package="bayesflow.data_adapters")
class CompositeDataAdapter(DataAdapter[TRaw, TProcessed]):
    """Composes multiple simple data adapters into a single, more complex adapter."""

    def __init__(self, data_adapters: Mapping[str, DataAdapter[TRaw, np.ndarray | None]]):
        self.data_adapters = data_adapters
        self.variable_counts = None

    def configure(self, raw_data: TRaw) -> TProcessed:
        processed_data = {}
        for key, data_adapter in self.data_adapters.items():
            data = data_adapter.configure(raw_data)
            if data is not None:
                processed_data[key] = data

        return processed_data

    def deconfigure(self, processed_data: TProcessed) -> TRaw:
        raw_data = {}
        for key, data_adapter in self.data_adapters.items():
            data = processed_data.get(key)
            if data is not None:
                raw_data |= data_adapter.deconfigure(data)

        return raw_data

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "CompositeDataAdapter":
        return cls(
            {
                key: deserialize(data_adapter, custom_objects)
                for key, data_adapter in config.pop("data_adapters").items()
            }
        )

    def get_config(self) -> dict:
        return {"data_adapters": {key: serialize(data_adapter) for key, data_adapter in self.data_adapters.items()}}
105 changes: 105 additions & 0 deletions bayesflow/data_adapters/concatenate_keys_data_adapter.py
@@ -0,0 +1,105 @@
from collections.abc import Mapping, Sequence
from keras.saving import (
    deserialize_keras_object as deserialize,
    register_keras_serializable as serializable,
    serialize_keras_object as serialize,
)
import numpy as np

from .composite_data_adapter import CompositeDataAdapter
from .data_adapter import DataAdapter
from .transforms import Transform

TRaw = Mapping[str, np.ndarray]
TProcessed = np.ndarray | None


@serializable(package="bayesflow.data_adapters")
class _ConcatenateKeysDataAdapter(DataAdapter[TRaw, TProcessed]):
    """Concatenates data from multiple keys into a single tensor."""

    def __init__(self, keys: Sequence[str]):
        if not keys:
            raise ValueError("At least one key must be provided.")

        self.keys = keys
        self.data_shapes = None
        self.is_configured = False

    def configure(self, raw_data: TRaw) -> TProcessed:
        if not self.is_configured:
            self.data_shapes = {key: value.shape for key, value in raw_data.items()}
            self.is_configured = True

        # filter and reorder data
        data = {}
        for key in self.keys:
            if key not in raw_data:
                # if a key is missing, we cannot configure, so we return None
                return None

            data[key] = raw_data[key]

        # concatenate all tensors along the last axis
        return np.concatenate(list(data.values()), axis=-1)

    def deconfigure(self, processed_data: TProcessed) -> TRaw:
        if not self.is_configured:
            raise ValueError("You must call `configure` at least once before calling `deconfigure`.")

        # split the concatenated tensor back into per-key slices using the stored shapes
        data = {}
        start = 0
        for key in self.keys:
            stop = start + self.data_shapes[key][-1]
            data[key] = np.take(processed_data, list(range(start, stop)), axis=-1)
            start = stop

        return data

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "_ConcatenateKeysDataAdapter":
        instance = cls(config["keys"])
        instance.data_shapes = config.get("data_shapes")
        instance.is_configured = config.get("is_configured", False)
        return instance

    def get_config(self) -> dict:
        return {"keys": self.keys, "data_shapes": self.data_shapes, "is_configured": self.is_configured}


@serializable(package="bayesflow.data_adapters")
class ConcatenateKeysDataAdapter(CompositeDataAdapter):
    """Concatenates data from multiple keys into multiple tensors."""

    def __init__(self, *, transforms: Sequence[Transform] | None = None, **keys: Sequence[str]):
        self.transforms = transforms or []
        self.keys = keys
        data_adapters = {key: _ConcatenateKeysDataAdapter(value) for key, value in keys.items()}
        super().__init__(data_adapters)

    def configure(self, raw_data):
        data = raw_data

        for transform in self.transforms:
            data = transform(data, inverse=False)

        return super().configure(data)

    def deconfigure(self, processed_data):
        data = super().deconfigure(processed_data)

        for transform in reversed(self.transforms):
            data = transform(data, inverse=True)

        return data

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "ConcatenateKeysDataAdapter":
        return cls(**config["keys"], transforms=deserialize(config.get("transforms"), custom_objects))

    def get_config(self) -> dict:
        return {"keys": self.keys, "transforms": serialize(self.transforms)}
29 changes: 29 additions & 0 deletions bayesflow/data_adapters/data_adapter.py
@@ -0,0 +1,29 @@
from typing import Generic, TypeVar


TRaw = TypeVar("TRaw")
TProcessed = TypeVar("TProcessed")


class DataAdapter(Generic[TRaw, TProcessed]):
    """Construct deep-learning-ready data from raw data, and deconstruct it back again."""

    def configure(self, raw_data: TRaw) -> TProcessed:
        """Construct deep-learning-ready data from raw data."""
        raise NotImplementedError

    def deconfigure(self, processed_data: TProcessed) -> TRaw:
        """Reconstruct raw data from deep-learning-ready processed data.

        Note that configuration is not required to be bijective, so this method is only a 'best effort'
        attempt and may return incomplete or different raw data.
        """
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "DataAdapter":
        """Construct a data adapter from a configuration dictionary."""
        raise NotImplementedError

    def get_config(self) -> dict:
        """Return a configuration dictionary."""
        raise NotImplementedError
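
For concreteness, a minimal (hypothetical) implementation of the interface; it casts values to float32 on the way in and, since the cast is not undone, also illustrates the 'best effort' nature of `deconfigure`:

import numpy as np

class CastDataAdapter(DataAdapter[dict, dict]):
    # hypothetical example: float32-cast every array in the raw data
    def configure(self, raw_data: dict) -> dict:
        return {key: np.asarray(value, dtype="float32") for key, value in raw_data.items()}

    def deconfigure(self, processed_data: dict) -> dict:
        # best effort: the original dtypes are not restored
        return dict(processed_data)

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "CastDataAdapter":
        return cls()

    def get_config(self) -> dict:
        return {}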
44 changes: 44 additions & 0 deletions bayesflow/data_adapters/flow_matching_data_adapter.py
@@ -0,0 +1,44 @@
from keras.saving import register_keras_serializable as serializable
import numpy as np
from typing import TypeVar

from bayesflow.utils import optimal_transport

from .data_adapter import DataAdapter


TRaw = TypeVar("TRaw")
TProcessed = dict[str, np.ndarray | tuple[np.ndarray, ...]]


@serializable(package="bayesflow.data_adapters")
class FlowMatchingDataAdapter(DataAdapter[TRaw, TProcessed]):
    """Wraps a data adapter, applying all further processing required for Optimal Transport Flow Matching.

    Useful to move these operations into a worker process, so as not to slow down training.
    """

    def __init__(self, inner: DataAdapter[TRaw, dict[str, np.ndarray]], key: str = "inference_variables", **kwargs):
        self.inner = inner
        self.key = key
        self.kwargs = kwargs

    def configure(self, raw_data: TRaw) -> TProcessed:
        processed_data = self.inner.configure(raw_data)

        # target samples x1, base (noise) samples x0, and per-sample interpolation times t
        x1 = processed_data[self.key]
        x0 = np.random.standard_normal(size=x1.shape).astype(x1.dtype)
        t = np.random.uniform(size=x1.shape[0]).astype(x1.dtype)

        # broadcast t over all non-batch dimensions
        expand_index = [slice(None)] + [None] * (x1.ndim - 1)
        t = t[tuple(expand_index)]

        # re-pair base and target samples via optimal transport
        x0, x1 = optimal_transport(x0, x1, **self.kwargs, numpy=True)

        # linear interpolant and its time derivative, the regression target
        x = t * x1 + (1 - t) * x0
        target_velocity = x1 - x0

        return processed_data | {self.key: (x0, x1, t, x, target_velocity)}

    def deconfigure(self, variables: TProcessed) -> TRaw:
        return self.inner.deconfigure(variables)
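
The tuple packed under `self.key` follows the usual flow matching construction: the interpolant is x = t * x1 + (1 - t) * x0, and the regression target x1 - x0 is its time derivative. A toy numeric check of just that interpolation step, independent of `optimal_transport`:

import numpy as np

x0 = np.zeros((4, 2), dtype="float32")  # base (noise) samples
x1 = np.ones((4, 2), dtype="float32")   # data samples
t = np.full((4, 1), 0.25, dtype="float32")

x = t * x1 + (1 - t) * x0               # interpolant at time t
target_velocity = x1 - x0               # time derivative of the interpolant

assert np.allclose(x, 0.25)
assert np.allclose(target_velocity, 1.0)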
5 changes: 5 additions & 0 deletions bayesflow/data_adapters/transforms/__init__.py
@@ -0,0 +1,5 @@
from .constrain_bounded import ConstrainBounded
from .lambda_transform import LambdaTransform
from .numpy_transform import NumpyTransform
from .standardize import Standardize
from .transform import Transform
140 changes: 140 additions & 0 deletions bayesflow/data_adapters/transforms/constrain_bounded.py
@@ -0,0 +1,140 @@
from collections.abc import Sequence
from keras.saving import (
    deserialize_keras_object as deserialize,
    register_keras_serializable as serializable,
    serialize_keras_object as serialize,
)
import numpy as np

from bayesflow.utils.numpy_utils import (
    inverse_sigmoid,
    inverse_softplus,
    sigmoid,
    softplus,
)

from .lambda_transform import LambdaTransform


@serializable(package="bayesflow.data_adapters")
class ConstrainBounded(LambdaTransform):
    """Constrains a parameter with a lower and/or upper bound."""

    def __init__(
        self,
        parameters: str | Sequence[str] | None = None,
        /,
        *,
        lower: np.ndarray | None = None,
        upper: np.ndarray | None = None,
        method: str,
    ):
        self.lower = lower
        self.upper = upper
        self.method = method

        if lower is None and upper is None:
            raise ValueError("At least one of 'lower' or 'upper' must be provided.")

        if lower is not None and upper is not None:
            if np.any(lower >= upper):
                raise ValueError("The lower bound must be strictly less than the upper bound.")

            # double bounded case
            match method:
                case "clip":

                    def constrain(x):
                        return np.clip(x, lower, upper)

                    def unconstrain(x):
                        # not bijective
                        return x
                case "sigmoid":

                    def constrain(x):
                        return (upper - lower) * sigmoid(x) + lower

                    def unconstrain(x):
                        return inverse_sigmoid((x - lower) / (upper - lower))
                case str() as name:
                    raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
                case other:
                    raise TypeError(f"Expected a method name, got {other!r}.")
        else:
            # single bounded case
            if lower is not None:
                match method:
                    case "clip":

                        def constrain(x):
                            return np.clip(x, lower, np.inf)

                        def unconstrain(x):
                            # not bijective
                            return x
                    case "softplus":

                        def constrain(x):
                            return softplus(x) + lower

                        def unconstrain(x):
                            return inverse_softplus(x - lower)
                    case "exp":

                        def constrain(x):
                            return np.exp(x) + lower

                        def unconstrain(x):
                            return np.log(x - lower)
                    case str() as name:
                        raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
                    case other:
                        raise TypeError(f"Expected a method name, got {other!r}.")
            else:
                match method:
                    case "clip":

                        def constrain(x):
                            return np.clip(x, -np.inf, upper)

                        def unconstrain(x):
                            # not bijective
                            return x
                    case "softplus":

                        def constrain(x):
                            return -softplus(-x) + upper

                        def unconstrain(x):
                            return -inverse_softplus(-(x - upper))
                    case "exp":

                        def constrain(x):
                            return -np.exp(-x) + upper

                        def unconstrain(x):
                            return -np.log(-x + upper)
                    case str() as name:
                        raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
                    case other:
                        raise TypeError(f"Expected a method name, got {other!r}.")

        super().__init__(parameters, forward=unconstrain, inverse=constrain)

    @classmethod
    def from_config(cls, config: dict, custom_objects=None) -> "ConstrainBounded":
        return cls(
            deserialize(config["parameters"], custom_objects),
            lower=deserialize(config["lower"], custom_objects),
            upper=deserialize(config["upper"], custom_objects),
            method=deserialize(config["method"], custom_objects),
        )

    def get_config(self) -> dict:
        return {
            "parameters": serialize(self.parameters),
            "lower": serialize(self.lower),
            "upper": serialize(self.upper),
            "method": serialize(self.method),
        }
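
A round-trip sketch, assuming the transform call convention used in `ConcatenateKeysDataAdapter` (`transform(data, inverse=...)`) and that scalar bounds are accepted where the annotations say `np.ndarray`; the parameter name is illustrative:

import numpy as np
from bayesflow.data_adapters.transforms import ConstrainBounded

# sigma > 0: softplus bijectively maps the real line onto (lower, inf)
transform = ConstrainBounded("sigma", lower=0.0, method="softplus")

data = {"sigma": np.array([0.5, 1.0, 2.0])}
unconstrained = transform(data, inverse=False)        # forward: map into the unconstrained space
constrained = transform(unconstrained, inverse=True)  # inverse: map back inside the bounds
assert np.allclose(constrained["sigma"], data["sigma"])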