Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def get_config(self) -> dict:

return serialize(config)

def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -> dict[str, np.ndarray]:
def forward(
self, data: dict[str, any], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the forward direction.

Parameters
Expand All @@ -88,22 +90,33 @@ def forward(self, data: dict[str, any], *, stage: str = "inference", **kwargs) -
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
log_det_jac: bool, optional
Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
Additional keyword arguments passed to each transform.

Returns
-------
dict
The transformed data.
dict | tuple[dict, dict]
The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
data = data.copy()
if not log_det_jac:
for transform in self.transforms:
data = transform(data, stage=stage, **kwargs)
return data

log_det_jac = {}
for transform in self.transforms:
data = transform(data, stage=stage, **kwargs)
transformed_data = transform(data, stage=stage, **kwargs)
log_det_jac = transform.log_det_jac(data, log_det_jac, **kwargs)
data = transformed_data

return data
return data, log_det_jac

def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kwargs) -> dict[str, any]:
def inverse(
self, data: dict[str, np.ndarray], *, stage: str = "inference", log_det_jac: bool = False, **kwargs
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the inverse direction.

Parameters
Expand All @@ -112,24 +125,32 @@ def inverse(self, data: dict[str, np.ndarray], *, stage: str = "inference", **kw
The data to be transformed.
stage : str, one of ["training", "validation", "inference"]
The stage the function is called in.
log_det_jac: bool, optional
Whether to return the log determinant of the Jacobian of the transforms.
**kwargs : dict
Additional keyword arguments passed to each transform.

Returns
-------
dict
The transformed data.
dict | tuple[dict, dict]
The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
data = data.copy()
if not log_det_jac:
for transform in reversed(self.transforms):
data = transform(data, stage=stage, inverse=True, **kwargs)
return data

log_det_jac = {}
for transform in reversed(self.transforms):
data = transform(data, stage=stage, inverse=True, **kwargs)
log_det_jac = transform.log_det_jac(data, log_det_jac, inverse=True, **kwargs)

return data
return data, log_det_jac

def __call__(
self, data: Mapping[str, any], *, inverse: bool = False, stage="inference", **kwargs
) -> dict[str, np.ndarray]:
) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""Apply the transforms in the given direction.

Parameters
Expand All @@ -145,8 +166,8 @@ def __call__(

Returns
-------
dict
The transformed data.
dict | tuple[dict, dict]
The transformed data or tuple of transformed data and log determinant of the Jacobian.
"""
if inverse:
return self.inverse(data, stage=stage, **kwargs)
Expand Down
34 changes: 34 additions & 0 deletions bayesflow/adapters/transforms/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,37 @@ def extra_repr(self) -> str:
result += f", axis={self.axis}"

return result

def log_det_jac(
self,
data: dict[str, np.ndarray],
log_det_jac: dict[str, np.ndarray],
*,
strict: bool = False,
inverse: bool = False,
**kwargs,
) -> dict[str, np.ndarray]:
# copy to avoid side effects
log_det_jac = log_det_jac.copy()

if inverse:
if log_det_jac.get(self.into) is not None:
raise ValueError(
"Cannot obtain an inverse Jacobian of concatenation. "
"Transform your variables before you concatenate."
)

return log_det_jac

required_keys = set(self.keys)
available_keys = set(log_det_jac.keys())
common_keys = available_keys & required_keys

if len(common_keys) == 0:
return log_det_jac

parts = [log_det_jac.pop(key) for key in common_keys]

log_det_jac[self.into] = sum(parts)

return log_det_jac
29 changes: 29 additions & 0 deletions bayesflow/adapters/transforms/constrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ def constrain(x):

def unconstrain(x):
return inverse_sigmoid((x - lower) / (upper - lower))

def ldj(x):
x = (x - lower) / (upper - lower)
return -np.log(x) - np.log1p(-x) - np.log(upper - lower)

case str() as name:
raise ValueError(f"Unsupported method name for double bounded constraint: '{name}'.")
case other:
Expand All @@ -101,13 +106,22 @@ def constrain(x):

def unconstrain(x):
return inverse_softplus(x - lower)

def ldj(x):
x = x - lower
return x - np.log(np.exp(x) - 1)

case "exp" | "log":

def constrain(x):
return np.exp(x) + lower

def unconstrain(x):
return np.log(x - lower)

def ldj(x):
return -np.log(x - lower)

case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
Expand All @@ -122,13 +136,21 @@ def constrain(x):

def unconstrain(x):
return -inverse_softplus(-(x - upper))

def ldj(x):
x = -(x - upper)
return x - np.log(np.exp(x) - 1)

case "exp" | "log":

def constrain(x):
return -np.exp(-x) + upper

def unconstrain(x):
return -np.log(-x + upper)

def ldj(x):
return -np.log(-x + upper)
case str() as name:
raise ValueError(f"Unsupported method name for single bounded constraint: '{name}'.")
case other:
Expand All @@ -142,6 +164,7 @@ def unconstrain(x):

self.constrain = constrain
self.unconstrain = unconstrain
self.ldj = ldj

# do this last to avoid serialization issues
match inclusive:
Expand Down Expand Up @@ -178,3 +201,9 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:
def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
# inverse means network space -> data space, so constrain the data
return self.constrain(data)

def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
ldj = self.ldj(data)
if inverse:
ldj = -ldj
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
3 changes: 3 additions & 0 deletions bayesflow/adapters/transforms/drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@

def extra_repr(self) -> str:
return "[" + ", ".join(map(repr, self.keys)) + "]"

def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

Check warning on line 51 in bayesflow/adapters/transforms/drop.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/drop.py#L51

Added line #L51 was not covered by tests
3 changes: 3 additions & 0 deletions bayesflow/adapters/transforms/elementwise_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
raise NotImplementedError

def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray | None:
return None

Check warning on line 30 in bayesflow/adapters/transforms/elementwise_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/elementwise_transform.py#L30

Added line #L30 was not covered by tests
30 changes: 28 additions & 2 deletions bayesflow/adapters/transforms/filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,35 @@
return predicate(key, value, inverse=inverse)

def _apply_transform(self, key: str, value: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
transform = self._get_transform(key)

return transform(value, inverse=inverse, **kwargs)

def _get_transform(self, key: str) -> ElementwiseTransform:
if key not in self.transform_map:
self.transform_map[key] = self.transform_constructor(**self.kwargs)

transform = self.transform_map[key]
return self.transform_map[key]

return transform(value, inverse=inverse, **kwargs)
def log_det_jac(
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
):
data = data.copy()

if strict and self.include is not None:
missing_keys = set(self.include) - set(data.keys())
if missing_keys:
raise KeyError(f"Missing keys from include list: {missing_keys!r}")

Check warning on line 171 in bayesflow/adapters/transforms/filter_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/filter_transform.py#L171

Added line #L171 was not covered by tests

for key, value in data.items():
if self._should_transform(key, value, inverse=False):
transform = self._get_transform(key)
ldj = transform.log_det_jac(value, **kwargs)
if ldj is None:
continue

Check warning on line 178 in bayesflow/adapters/transforms/filter_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/filter_transform.py#L178

Added line #L178 was not covered by tests
elif key in log_det_jac:
log_det_jac[key] += ldj

Check warning on line 180 in bayesflow/adapters/transforms/filter_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/filter_transform.py#L180

Added line #L180 was not covered by tests
else:
log_det_jac[key] = ldj

return log_det_jac
3 changes: 3 additions & 0 deletions bayesflow/adapters/transforms/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,6 @@

def extra_repr(self) -> str:
return "[" + ", ".join(map(repr, self.keys)) + "]"

def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac)

Check warning on line 62 in bayesflow/adapters/transforms/keep.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/keep.py#L62

Added line #L62 was not covered by tests
9 changes: 9 additions & 0 deletions bayesflow/adapters/transforms/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,12 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:

def get_config(self) -> dict:
return serialize({"p1": self.p1})

def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
if self.p1:
ldj = -np.log1p(data)
else:
ldj = -np.log(data)
if inverse:
ldj = -ldj
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
45 changes: 33 additions & 12 deletions bayesflow/adapters/transforms/map_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,8 @@
def forward(self, data: dict[str, np.ndarray], *, strict: bool = True, **kwargs) -> dict[str, np.ndarray]:
data = data.copy()

required_keys = set(self.transform_map.keys())
available_keys = set(data.keys())
missing_keys = required_keys - available_keys

if strict and missing_keys:
raise KeyError(f"Missing keys: {missing_keys!r}")
if strict:
self._check_keys(data)

for key, transform in self.transform_map.items():
if key in data:
Expand All @@ -57,15 +53,40 @@
def inverse(self, data: dict[str, np.ndarray], *, strict: bool = False, **kwargs) -> dict[str, np.ndarray]:
data = data.copy()

required_keys = set(self.transform_map.keys())
available_keys = set(data.keys())
missing_keys = required_keys - available_keys

if strict and missing_keys:
raise KeyError(f"Missing keys: {missing_keys!r}")
if strict:
self._check_keys(data)

Check warning on line 57 in bayesflow/adapters/transforms/map_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/map_transform.py#L57

Added line #L57 was not covered by tests

for key, transform in self.transform_map.items():
if key in data:
data[key] = transform.inverse(data[key], **kwargs)

return data

def log_det_jac(
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], *, strict: bool = True, **kwargs
) -> dict[str, np.ndarray]:
data = data.copy()

if strict:
self._check_keys(data)

for key, transform in self.transform_map.items():
if key in data:
ldj = transform.log_det_jac(data[key], **kwargs)

if ldj is None:
continue

Check warning on line 78 in bayesflow/adapters/transforms/map_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/map_transform.py#L78

Added line #L78 was not covered by tests
elif key in log_det_jac:
log_det_jac[key] += ldj
else:
log_det_jac[key] = ldj

return log_det_jac

def _check_keys(self, data: dict[str, np.ndarray]):
required_keys = set(self.transform_map.keys())
available_keys = set(data.keys())
missing_keys = required_keys - available_keys

if missing_keys:
raise KeyError(f"Missing keys: {missing_keys!r}")

Check warning on line 92 in bayesflow/adapters/transforms/map_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/map_transform.py#L92

Added line #L92 was not covered by tests
3 changes: 3 additions & 0 deletions bayesflow/adapters/transforms/numpy_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,6 @@

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return self._inverse(data)

def log_det_jac(self, data, inverse=False, **kwargs):
raise NotImplementedError("log determinant of the Jacobian of the numpy transforms are not implemented yet")

Check warning on line 77 in bayesflow/adapters/transforms/numpy_transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/numpy_transform.py#L77

Added line #L77 was not covered by tests
3 changes: 3 additions & 0 deletions bayesflow/adapters/transforms/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di

def extra_repr(self) -> str:
return f"{self.from_key!r} -> {self.to_key!r}"

def log_det_jac(self, data: dict[str, any], log_det_jac: dict[str, any], inverse: bool = False, **kwargs):
return self.inverse(data=log_det_jac) if inverse else self.forward(data=log_det_jac, strict=False)
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ def forward(self, data: np.ndarray, **kwargs) -> np.ndarray:

def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
return data / self.scale

def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
ldj = np.log(np.abs(self.scale))
ldj = np.full(data.shape, ldj)
if inverse:
ldj = -ldj
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
6 changes: 6 additions & 0 deletions bayesflow/adapters/transforms/sqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,9 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:

def get_config(self) -> dict:
return {}

def log_det_jac(self, data: np.ndarray, inverse: bool = False, **kwargs) -> np.ndarray:
ldj = -0.5 * np.log(data) - np.log(2)
if inverse:
ldj = -ldj
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
7 changes: 7 additions & 0 deletions bayesflow/adapters/transforms/standardize.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,10 @@ def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray:
std = np.broadcast_to(self.std, data.shape)

return data * std + mean

def log_det_jac(self, data, inverse: bool = False, **kwargs) -> np.ndarray:
std = np.broadcast_to(self.std, data.shape)
ldj = np.log(np.abs(std))
if inverse:
ldj = -ldj
return np.sum(ldj, axis=tuple(range(1, ldj.ndim)))
5 changes: 5 additions & 0 deletions bayesflow/adapters/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@

def extra_repr(self) -> str:
return ""

def log_det_jac(
self, data: dict[str, np.ndarray], log_det_jac: dict[str, np.ndarray], inverse: bool = False, **kwargs
) -> dict[str, np.ndarray]:
return log_det_jac

Check warning on line 42 in bayesflow/adapters/transforms/transform.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/adapters/transforms/transform.py#L42

Added line #L42 was not covered by tests
Loading