Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
9090112
Splines draft
stefanradev93 Nov 28, 2024
38501d5
update keras requirement
LarsKue Jan 13, 2025
afb7f37
small improvements to error messages
LarsKue Jan 13, 2025
95a8528
add rq spline function
LarsKue Jan 13, 2025
ffd1dd1
add spline transform
LarsKue Jan 13, 2025
2a8a3ea
update searchsorted utils for jax
LarsKue Jan 13, 2025
601b0c5
update tests
LarsKue Jan 13, 2025
9974a71
add assert_allclose util for improved messages
LarsKue Jan 13, 2025
b1e52c2
parametrize transform for flow tests
LarsKue Jan 13, 2025
f688454
update jacobian, jacobian trace, vjp, jvp, and corresponding usages a…
LarsKue Jan 14, 2025
8040020
Merge branch 'dev' into splines
LarsKue Jan 14, 2025
d58fbaa
Merge branch 'dev' into splines
LarsKue Jan 14, 2025
91ad531
fix imports, remove old jacobian and jvp, fix application in free for…
LarsKue Jan 14, 2025
47e28aa
improve logdet computation in free form flows
LarsKue Jan 14, 2025
c3e72d9
Fix comparison for symbolic tensors under tf
stefanradev93 Jan 18, 2025
f4d41a9
Add splines to twomoons notebook
stefanradev93 Jan 18, 2025
12a80f8
improve pad utility
LarsKue Jan 20, 2025
4861dfa
fix missing left edge in spline
LarsKue Jan 20, 2025
e59055a
fix inside mask edge case
LarsKue Jan 20, 2025
8a4c2dd
explicitly set bias initializer
LarsKue Jan 21, 2025
a1ce42e
add better expand utility
LarsKue Jan 21, 2025
6861cdb
small clean up, renaming
LarsKue Jan 21, 2025
577a44e
fix indexing, fix inside check
LarsKue Jan 21, 2025
543281c
dump
LarsKue Jan 23, 2025
0d907e4
fix sign of log jacobian for inverse pass in rq spline
LarsKue Jan 23, 2025
dad61cf
fix parameter splitting for spline transform
LarsKue Jan 23, 2025
ef7de59
improve readability
LarsKue Jan 23, 2025
c89c5d0
fix scale and shift trailing dimension
LarsKue Jan 23, 2025
00aeb0c
fix inverse pass return value
LarsKue Jan 23, 2025
abff663
correctly choose bins once for each dimension, even for multi-dimensi…
LarsKue Jan 23, 2025
1cd2fb5
run formatter
LarsKue Jan 23, 2025
8bda832
Merge branch 'dev' into splines
LarsKue Jan 23, 2025
62e1ef5
reduce searchsorted log spam
LarsKue Jan 23, 2025
af26ba6
log backend used at setup
LarsKue Jan 23, 2025
20814b7
remove maximum message cache size
LarsKue Jan 23, 2025
6a526dd
Improve warning message for jax searchsorted
LarsKue Jan 23, 2025
08a2182
Fix spline parameter binning for compiled contexts
LarsKue Jan 23, 2025
a3ce91a
update inverse transform same as forward
LarsKue Jan 24, 2025
2454c9b
Update TwoMoons notebook with splines WIP [skip ci]
stefanradev93 Jan 26, 2025
0c9c9fd
fix spline inverse call for out of bounds values
LarsKue Jan 27, 2025
4742853
Merge remote-tracking branch 'origin/splines' into splines
LarsKue Jan 27, 2025
db8dab1
Add working splines
stefanradev93 Jan 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bayesflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def setup():

torch.autograd.set_grad_enabled(False)

from bayesflow.utils import logging

logging.info(f"Using backend {keras.backend.backend()!r}")


# call and clean up namespace
setup()
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plots/calibration_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def calibration_ecdf(
titles = ["Stacked ECDFs"]

for ax, title in zip(plot_data["axes"].flat, titles):
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize)
ax.set_title(title, fontsize=title_fontsize)

Expand Down
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plots/mmd_hypothesis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def fill_area_under_kde(kde_object, x_start, x_end=None, **kwargs):

mmd_critical = ops.quantile(mmd_null, 1 - alpha_level)
fill_area_under_kde(
kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level*100)}% rejection area"
kde, mmd_critical, color=alpha_color, alpha=0.5, label=rf"{int(alpha_level * 100)}% rejection area"
)

if truncate_v_lines_at_kde:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def f_teacher(x, t):
ops.cos(t) * ops.sin(t) * self.sigma_data,
)

teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents)
teacher_output, cos_sin_dFdt = jvp(f_teacher, primals, tangents, return_output=True)
teacher_output = ops.stop_gradient(teacher_output)
cos_sin_dFdt = ops.stop_gradient(cos_sin_dFdt)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import keras

from keras.saving import register_keras_serializable as serializable

from bayesflow.types import Tensor
Expand All @@ -24,6 +23,7 @@ def __init__(self, subnet: str | type = "mlp", transform: str = "affine", **kwar

output_projector_kwargs = kwargs.get("output_projector_kwargs", {})
output_projector_kwargs.setdefault("kernel_initializer", "zeros")
output_projector_kwargs.setdefault("bias_initializer", "zeros")
self.output_projector = keras.layers.Dense(units=None, **output_projector_kwargs)

# serialization: store all parameters necessary to call __init__
Expand Down
81 changes: 81 additions & 0 deletions bayesflow/networks/coupling_flow/transforms/_rational_quadratic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import TypedDict

import keras

from bayesflow.types import Tensor


class Edges(TypedDict):
left: Tensor
right: Tensor
bottom: Tensor
top: Tensor


class Derivatives(TypedDict):
left: Tensor
right: Tensor


def _rational_quadratic_spline(
x: Tensor, edges: Edges, derivatives: Derivatives, inverse: bool = False
) -> (Tensor, Tensor):
# rename variables to match the paper:

# $x^{(k)}$
xk = edges["left"]

# $x^{(k+1)}$
xkp = edges["right"]

# $y^{(k)}$
yk = edges["bottom"]

# $y^{(k+1)}$
ykp = edges["top"]

# $delta^{(k)}$
dk = derivatives["left"]

# $delta^{(k+1)}$
dkp = derivatives["right"]

# commonly used values
dx = xkp - xk
dy = ykp - yk
sk = dy / dx

if not inverse:
xi = (x - xk) / dx

# Eq. 4 in the paper
numerator = dy * (sk * xi**2 + dk * xi * (1 - xi))
denominator = sk + (dkp + dk - 2 * sk) * xi * (1 - xi)
result = yk + numerator / denominator
else:
# rename for clarity
y = x

# Eq. 6-8 in the paper
a = dy * (sk - dk) + (y - yk) * (dkp + dk - 2 * sk)
b = dy * dk - (y - yk) * (dkp + dk - 2 * sk)
c = -sk * (y - yk)

# Eq. 29 in the appendix of the paper
discriminant = b**2 - 4 * a * c

# the discriminant must be positive, even when the spline is called out of bounds
discriminant = keras.ops.maximum(discriminant, 0)

xi = 2 * c / (-b - keras.ops.sqrt(discriminant))
result = xi * dx + xk

# Eq 5 in the paper
numerator = sk**2 * (dkp * xi**2 + 2 * sk * xi * (1 - xi) + dk * (1 - xi) ** 2)
denominator = (sk + (dkp + dk - 2 * sk) * xi * (1 - xi)) ** 2
log_jac = keras.ops.log(numerator) - keras.ops.log(denominator)

if inverse:
log_jac = -log_jac

return result, log_jac
Loading
Loading