Skip to content

Conversation

@LarsKue
Copy link
Contributor

@LarsKue LarsKue commented Jan 14, 2025

See title. As a side-effect, this PR includes an update to the interface for backend-agnostic Jacobian computation (and related functions, such as vjp and jacobian_trace), previously discussed in #251.

The set transformer is failing some tests in my local env, but I assume that is unrelated to this PR.
The linter also appears to not be happy with dict_utils.py which is unchanged by this PR.

stefanradev93 and others added 11 commits November 28, 2024 11:10
# Conflicts:
#	bayesflow/utils/__init__.py
#	bayesflow/utils/jacobian/vjp.py
#	bayesflow/utils/jacobian_trace/compute_jacobian_trace.py
#	bayesflow/utils/jacobian_trace/estimate_jacobian_trace.py
#	bayesflow/utils/jvp.py
#	environment.yaml
@LarsKue LarsKue added the feature New feature or request label Jan 14, 2025
@LarsKue LarsKue self-assigned this Jan 14, 2025
@LarsKue LarsKue added this to the BayesFlow 2.0 milestone Jan 14, 2025
@LarsKue LarsKue requested a review from stefanradev93 January 14, 2025 13:56
out_type = "int32" if len(sorted_sequence) <= np.iinfo(np.int32).max else "int64"

indices = tf.searchsorted(sorted_sequence, values, side=side, out_type=out_type)
# always use int64 to avoid complicated graph code
Copy link
Contributor Author

@LarsKue LarsKue Jan 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not backend agnostic, as jax forces this kind of behavior. We can still up to int64 for all backends, but this needs to be implemented, still.

@LarsKue
Copy link
Contributor Author

LarsKue commented Jan 23, 2025

Still not working in jax because nonzero is not supported for compiled contexts. The only fix is to compute the spline on everything and throw away the garbage entries. I am working on it.

@stefanradev93 stefanradev93 merged commit 146f050 into dev Jan 28, 2025
10 of 13 checks passed
@stefanradev93 stefanradev93 deleted the splines branch January 28, 2025 05:57
han-ol pushed a commit to han-ol/bayesflow that referenced this pull request Jan 28, 2025
…low-org#291)

* Splines draft

* update keras requirement

* small improvements to error messages

* add rq spline function

* add spline transform

* update searchsorted utils for jax
also add padd util

* update tests

* add assert_allclose util for improved messages

* parametrize transform for flow tests

* update jacobian, jacobian trace, vjp, jvp, and corresponding usages and tests

* fix imports, remove old jacobian and jvp, fix application in free form flow

* improve logdet computation in free form flows

* Fix comparison for symbolic tensors under tf

* Add splines to twomoons notebook

* improve pad utility

* fix missing left edge in spline

* fix inside mask edge case

* explicitly set bias initializer

* add better expand utility

* small clean up, renaming

* fix indexing, fix inside check

* dump

* fix sign of log jacobian for inverse pass in rq spline

* fix parameter splitting for spline transform

* improve readability

* fix scale and shift trailing dimension

* fix inverse pass return value

* correctly choose bins once for each dimension, even for multi-dimensional inputs

* run formatter

* reduce searchsorted log spam

* log backend used at setup

* remove maximum message cache size

* Improve warning message for jax searchsorted

* Fix spline parameter binning for compiled contexts

* update inverse transform same as forward

* Update TwoMoons notebook with splines WIP [skip ci]

* fix spline inverse call for out of bounds values

* Add working splines

---------

Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

feature New feature or request

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants