-
Notifications
You must be signed in to change notification settings - Fork 78
Add Rational Quadratic Spline Transforms to Normalizing Flows #291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
also add padd util
# Conflicts: # bayesflow/utils/__init__.py # bayesflow/utils/jacobian/vjp.py # bayesflow/utils/jacobian_trace/compute_jacobian_trace.py # bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py # bayesflow/utils/jvp.py # environment.yaml
| out_type = "int32" if len(sorted_sequence) <= np.iinfo(np.int32).max else "int64" | ||
|
|
||
| indices = tf.searchsorted(sorted_sequence, values, side=side, out_type=out_type) | ||
| # always use int64 to avoid complicated graph code |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not backend agnostic, as jax forces this kind of behavior. We can still up to int64 for all backends, but this needs to be implemented, still.
|
Still not working in jax because |
…low-org#291) * Splines draft * update keras requirement * small improvements to error messages * add rq spline function * add spline transform * update searchsorted utils for jax also add padd util * update tests * add assert_allclose util for improved messages * parametrize transform for flow tests * update jacobian, jacobian trace, vjp, jvp, and corresponding usages and tests * fix imports, remove old jacobian and jvp, fix application in free form flow * improve logdet computation in free form flows * Fix comparison for symbolic tensors under tf * Add splines to twomoons notebook * improve pad utility * fix missing left edge in spline * fix inside mask edge case * explicitly set bias initializer * add better expand utility * small clean up, renaming * fix indexing, fix inside check * dump * fix sign of log jacobian for inverse pass in rq spline * fix parameter splitting for spline transform * improve readability * fix scale and shift trailing dimension * fix inverse pass return value * correctly choose bins once for each dimension, even for multi-dimensional inputs * run formatter * reduce searchsorted log spam * log backend used at setup * remove maximum message cache size * Improve warning message for jax searchsorted * Fix spline parameter binning for compiled contexts * update inverse transform same as forward * Update TwoMoons notebook with splines WIP [skip ci] * fix spline inverse call for out of bounds values * Add working splines --------- Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
See title. As a side-effect, this PR includes an update to the interface for backend-agnostic Jacobian computation (and related functions, such as
vjpandjacobian_trace), previously discussed in #251.The set transformer is failing some tests in my local env, but I assume that is unrelated to this PR.
The linter also appears to not be happy with
dict_utils.pywhich is unchanged by this PR.