
Conversation

@vpratz (Collaborator) commented Nov 20, 2024

WIP; relates to #118.

Free-form flows can be an interesting model class for SBI. In my first trials (without much tuning), I could not match the performance of our existing inference networks. Nevertheless, a readily available implementation enables exploration and would therefore be worthwhile to have.

* implements the FFF loss (a rough sketch of the objective follows below)
* still missing: calculation of the log probability
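For orientation, here is a minimal, unconditional sketch of what an FFF-style objective can look like under the torch backend. All names (`encoder`, `decoder`, `fff_loss`) are illustrative, and the exact stop-gradient placement follows my reading of the free-form flows paper rather than this PR's code:

```python
import torch
from torch.func import jvp, vjp


def fff_loss(encoder, decoder, x, beta=50.0):
    # Hutchinson probe vector for the trace estimator
    v = torch.randn_like(x)

    # z = f(x) and J_f(x) @ v in one forward-mode pass
    z, jf_v = jvp(encoder, (x,), (v,))

    # x_rec = g(z) and a closure that computes v^T @ J_g(z)
    x_rec, decoder_vjp = vjp(decoder, z)
    (vt_jg,) = decoder_vjp(v)

    # surrogate for the log-det-Jacobian term: v^T J_g(z) J_f(x) v,
    # with the decoder Jacobian treated as a constant (stop-gradient)
    surrogate = (vt_jg.detach() * jf_v).flatten(1).sum(-1)

    # standard-normal latent log density (up to an additive constant)
    log_prob_z = -0.5 * (z**2).flatten(1).sum(-1)

    reconstruction = ((x - x_rec) ** 2).flatten(1).sum(-1)
    return (beta * reconstruction - log_prob_z - surrogate).mean()
```

In this sketch the reconstruction term trains the decoder, while the latent log density plus the surrogate trains the encoder; the log probability mentioned above as still missing would additionally require the exact log-det Jacobian at evaluation time.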
@vpratz added the feature and v2 labels on Nov 20, 2024
@vpratz self-assigned this on Nov 20, 2024
@stefanradev93 (Contributor)
Great idea and much needed! Looking forward to the implementation. :)

@LarsKue (Contributor) commented Nov 20, 2024

Also looking forward to this. Please also check out the already backend-agnostic make_vjp_fn. I think it would be good to make this a general util, now that it is used in multiple places.
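For illustration, a backend dispatch for such a general util could look roughly like the following. This is a sketch only; the function name, the module it would live in, and the tensorflow handling are assumptions, not the existing bayesflow code:

```python
import keras


def vjp(fn, *primals):
    """Return (outputs, vjp_fn) for the active Keras backend (sketch only)."""
    backend = keras.backend.backend()
    if backend == "torch":
        import torch

        return torch.func.vjp(fn, *primals)
    if backend == "jax":
        import jax

        return jax.vjp(fn, *primals)
    if backend == "tensorflow":
        import tensorflow as tf

        # record the forward pass once, then replay gradients on demand
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(primals)
            outputs = fn(*primals)

        def vjp_fn(cotangents):
            return tape.gradient(outputs, primals, output_gradients=cotangents)

        return outputs, vjp_fn
    raise NotImplementedError(f"vjp is not implemented for backend '{backend}'")
```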

Change `torch.autograd.functional.vjp` to `torch.func.vjp`, as the former broke gradient flow. The latter also uses the same API as JAX, making the code easier to parse.

Change `torch.autograd.functional.jvp` to `torch.func.jvp`, as recommended in the documentation:
https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html

Using `autograd.functional` seems to break the gradient flow, while `torch.func` does not produce problems in this regard.
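To make the API change concrete, a quick illustration of the torch.func variants (the toy function `f` is mine; the calls themselves are the documented torch.func interface):

```python
import torch


def f(x):
    return x**2


x = torch.randn(3)
v = torch.randn(3)

# torch.func.vjp mirrors jax.vjp: it returns the primal output and a closure.
out, vjp_fn = torch.func.vjp(f, x)
(vjp_out,) = vjp_fn(v)  # v^T @ df/dx

# torch.func.jvp takes primals and tangents as tuples and returns (output, J @ v).
out, jvp_out = torch.func.jvp(f, (x,), (v,))
```

These calls compose with further differentiation, which matches the observation above that gradient flow stays intact with torch.func.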
@vpratz changed the title from "WIP: Add free-form flows as inference networks" to "Add free-form flows as inference networks" on Nov 21, 2024
@vpratz marked this pull request as ready for review on November 21, 2024 07:52
@vpratz (Collaborator, Author) commented Nov 21, 2024

Thanks for the hint regarding make_vjp_fn, @LarsKue. I have renamed the function to vjp, analogous to torch and jax, which provide jvp and vjp with the same functionality but without the make_*_fn naming. If you prefer the previous naming, let me know and I will change it accordingly.
I also changed the torch function calls from the autograd.functional API to the func API. Using autograd.functional broke FFF training, probably because the gradients were not left intact. With func, it works without problems. If you see any issues with this change, let me know so that we can discuss how to proceed.

I also added batched jacobian and jacobian determinant computation. Please check the naming to ensure it fits your desired naming scheme.

The implementation of the density is finished as well, so I believe this PR is ready for review.

@LarsKue (Contributor) commented Nov 21, 2024

> I have renamed the function to vjp

Thanks, this is a good change.

> I also changed the torch function calls from the autograd.functional API to the func API.

I also like this change. I used autograd because I was copying from code I wrote over a year ago (before torch.func was included in the base torch installation).

> I also added batched jacobian and jacobian determinant computation. Please check the naming to ensure it fits your desired naming scheme.

These functions produce errors for me under the torch backend due to vmap. Please double-check. Jacobian computation is also closely related to Jacobian-trace computation, so we could adapt the fix from how the latter is written and merge these into one module.

The batch_wrap and double_output decorators could also be defined inside compute_jacobian if they survive the fix and are not used anywhere else. This avoids polluting the namespace or suggesting these have further uses.

> The implementation of the density is finished as well, so I believe this PR is ready for review.

The inference network itself looks good to me. Should be ready to merge once the tests pass!

@vpratz (Collaborator, Author) commented Nov 21, 2024

@LarsKue Thanks a lot for the review and your changes! The problem arises because when we have conditions, we want to split them into elements for vmap, but when conditions is None, this is not possible.

For now I have supplied a somewhat hacky fix that relies on the fact that this implementation of compute_jacobian splits the positional arguments into elements and keeps the keyword arguments as they are. Should we also allow "splittable" keyword arguments in compute_jacobian? If so, what would be good names for them?
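To make the pattern concrete, here is a rough sketch of the idea for the torch backend (a hypothetical simplification, not the actual compute_jacobian in this PR): positional arguments are mapped over the batch axis by vmap, while keyword arguments are captured in a closure and passed through unchanged, so a conditions=None keyword never has to be split.

```python
from torch.func import jacrev, vmap


def compute_jacobian_sketch(x_in, fn, *func_args, **func_kwargs):
    # per-sample function: vmap splits x_in and the positional args along the
    # batch axis, while keyword args (e.g. conditions=None) stay untouched
    def fn_single(x, *args):
        return fn(x, *args, **func_kwargs)

    x_out = vmap(fn_single)(x_in, *func_args)
    jac = vmap(jacrev(fn_single))(x_in, *func_args)
    return x_out, jac
```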

@paul-buerkner removed the v2 label on Nov 21, 2024
@LarsKue (Contributor) commented Nov 26, 2024

Looks good so far. I will provide another detailed review asap, and then we are ready to merge :)

@stefanradev93 (Contributor)

@LarsKue bump.

@LarsKue (Contributor) left a review comment

Looks good, thank you for the PR!

```python
def __init__(
    self,
    beta: float = 50.0,
    encoder_subnet: str | type = "mlp",
```
type is not serializable out of the box, so we would need a from_config method here. But we can add this later.
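For later reference, one possible (purely hypothetical) way to make the str | type argument round-trip through serialization is to fall back on Keras' registered-object helpers; the eventual from_config may of course look different:

```python
import keras


def serialize_subnet(subnet):
    # strings like "mlp" are already JSON-friendly; a class is stored under the
    # name it was registered with in Keras (hypothetical helper, not the PR's code)
    if isinstance(subnet, type):
        return {"registered_name": keras.saving.get_registered_name(subnet)}
    return subnet


def deserialize_subnet(config):
    if isinstance(config, dict):
        return keras.saving.get_registered_object(config["registered_name"])
    return config
```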

```python
else:
    inp = concatenate(x, conditions, axis=-1)
network_out = self.encoder_projector(
    self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
```
I would prefer this non-nested for better errors, but again, not a major issue.
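For concreteness, the non-nested version would read something like this (an illustrative rewrite of the snippet above, keeping the surrounding if/else as is; not the merged code):

```python
inp = concatenate(x, conditions, axis=-1)
subnet_out = self.encoder_subnet(inp, training=training, **kwargs)
network_out = self.encoder_projector(subnet_out, training=training, **kwargs)
```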


```python
x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
jac = ops.reshape(
    jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
```
Would prefer this in multiple lines
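For example, something along these lines (illustrative only):

```python
x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
batch_size = ops.shape(x_in)[0]
out_dim = ops.prod(list(ops.shape(x_out)[1:]))
in_dim = ops.prod(list(ops.shape(x_in)[1:]))
jac = ops.reshape(jac, (batch_size, out_dim, in_dim))
```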

@LarsKue merged commit 0537f2a into bayesflow-org:dev on Dec 2, 2024
13 checks passed
@vpratz deleted the feat-fff branch on March 1, 2025