Add free-form flows as inference networks #251
Conversation
* implements the FFF loss
* still missing: calculation of the log probability
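For context, here is a rough, illustrative sketch of the free-form flow training objective as described in the FFF paper, not the code in this PR: the encoder `f`, decoder `g`, single Hutchinson sample, and torch-only backend are all assumptions, as are the function names. The `beta` default of 50.0 simply mirrors the `beta` argument visible in the diff further below; the log probability noted as missing is a separate, exact computation rather than this training-time surrogate.

```python
import torch
from torch.func import vjp, jvp

def fff_loss(x, encoder, decoder, beta=50.0):
    """Illustrative single-sample free-form flow loss (not the PR's code)."""
    non_batch = tuple(range(1, x.ndim))
    v = torch.randn_like(x)  # Hutchinson probe vector

    # z = f(x) and v1 = v^T J_f(x); torch.func.vjp keeps the graph intact.
    z, encoder_vjp = vjp(encoder, x)
    (v1,) = encoder_vjp(v)

    # x_rec = g(z) and v2 = J_g(z) v via forward-mode AD.
    x_rec, v2 = jvp(decoder, (z,), (v,))

    # Surrogate whose gradient approximates that of log|det J_f(x)| when
    # g ≈ f^{-1}; the decoder-side factor is detached (stop-gradient trick).
    surrogate = torch.sum(v1 * v2.detach(), dim=non_batch)

    # Standard-normal latent log density (up to an additive constant).
    log_prob_latent = -0.5 * torch.sum(z**2, dim=non_batch)

    reconstruction = torch.sum((x - x_rec) ** 2, dim=non_batch)
    return torch.mean(-(log_prob_latent + surrogate) + beta * reconstruction)
```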
Great idea and much needed! Looking forward to the implementation. :)

Also looking forward to this. Please also check out the already backend-agnostic […]
Change `torch.autograd.functional.vjp` to `torch.func.vjp`, as the former implementation broke gradient flow. It also uses the same API as JAX, making the code easier to parse.

Change `torch.autograd.functional.jvp` to `torch.func.jvp`, as recommended in the documentation (https://pytorch.org/docs/stable/generated/torch.autograd.functional.jvp.html). Using `autograd.functional` seems to break the gradient flow, while `torch.func` does not produce problems in this regard.
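For illustration, a toy comparison of the two APIs (`fn` is just a placeholder function): the `torch.func` variants return the output together with either a vjp closure or the tangent output, mirroring `jax.vjp` / `jax.jvp`.

```python
import torch

def fn(x):
    return x ** 2

x = torch.randn(5)
v = torch.ones(5)

# Old style: torch.autograd.functional returns (output, product) directly.
out_old, vjp_old = torch.autograd.functional.vjp(fn, x, v)
out_old, jvp_old = torch.autograd.functional.jvp(fn, x, v)

# New style: torch.func mirrors the JAX API and composes with autograd.
out, vjp_fn = torch.func.vjp(fn, x)
(vjp_new,) = vjp_fn(v)                          # v^T (dfn/dx)
out, jvp_new = torch.func.jvp(fn, (x,), (v,))   # (dfn/dx) v
```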
Thanks for the hint regarding the […]. I also added batched Jacobian and Jacobian determinant computation. Please check the naming to ensure it fits your desired naming scheme. The implementation of the density is finished as well, so I believe this PR is ready for review.
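For reference, the exact density of a dimension-preserving flow follows the usual change-of-variables formula, log q(x) = log N(f(x); 0, I) + log|det J_f(x)|, which is where a batched Jacobian and its determinant come in. A minimal torch sketch, assuming flat features, a standard-normal latent, and hypothetical function names (not necessarily the PR's implementation):

```python
import math

import torch
from torch.func import vmap, jacrev

def exact_log_prob(x, encoder):
    # Change of variables for a dimension-preserving encoder f with a
    # standard-normal latent: log q(x) = log N(f(x); 0, I) + log|det J_f(x)|.
    z = encoder(x)                          # shape (batch, dim)
    jac = vmap(jacrev(encoder))(x)          # batched Jacobian, shape (batch, dim, dim)
    _, log_det = torch.linalg.slogdet(jac)  # batched log|det J_f(x)|
    log_latent = -0.5 * torch.sum(z**2, dim=-1) - 0.5 * z.shape[-1] * math.log(2 * math.pi)
    return log_latent + log_det
```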
Thanks, this is a good change.
I also like this change. I used autograd because I was copying from code I wrote over a year ago (before […]).

These functions produce errors for me under the torch backend due to vmap. Please double check. Jacobian computation is also very closely related to Jacobian trace computation, so we could take the fix from how the latter is written, and merge these into one module. The […]
The inference network itself looks good to me. Should be ready to merge once the tests pass!

@LarsKue Thanks a lot for the review and your changes! The problem arises because when we have conditions, we want to split them into elements for vmap, but when `conditions` is None, this is not possible. For now, I have supplied a somewhat hacky fix, where we use the fact that this implementation of […]
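For illustration, one common workaround (a hypothetical helper, not the fix applied in this PR) is to branch on whether conditions are present, so that `vmap` never has to map over a `None` value:

```python
import torch
from torch.func import vmap, jacrev

def batched_jacobian(fn, x, conditions=None):
    # Hypothetical sketch: vmap cannot map over a None argument, so only
    # include `conditions` in the mapped call when it is actually present.
    # Assumes fn(x) is valid when no conditions are given, and fn(x, conditions)
    # otherwise; jacrev differentiates w.r.t. the first argument by default.
    if conditions is None:
        return vmap(jacrev(fn))(x)
    # Both x and conditions are mapped over their leading batch axis.
    return vmap(jacrev(fn))(x, conditions)
```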
Looks good so far. I will provide another detailed review ASAP, and then we are ready to merge :)
@LarsKue bump.
LarsKue left a comment:
Looks good, thank you for the PR!
def __init__(
    self,
    beta: float = 50.0,
    encoder_subnet: str | type = "mlp",
`type` is not serializable out of the box, so we would need a `from_config` method here. But we can add this later.
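A rough sketch of what such a round trip could look like (the class and attribute names are hypothetical; `keras.saving.get_registered_name` / `get_registered_object` are existing Keras utilities):

```python
import keras

class FreeFormFlow(keras.layers.Layer):  # hypothetical stand-in class
    def __init__(self, encoder_subnet: str | type = "mlp", **kwargs):
        super().__init__(**kwargs)
        self.encoder_subnet_spec = encoder_subnet

    def get_config(self):
        config = super().get_config()
        spec = self.encoder_subnet_spec
        # Store a registered name instead of the raw type object.
        config["encoder_subnet"] = (
            spec if isinstance(spec, str) else keras.saving.get_registered_name(spec)
        )
        return config

    @classmethod
    def from_config(cls, config):
        name = config["encoder_subnet"]
        # Resolve registered custom classes back to a type; plain strings
        # like "mlp" fall through unchanged.
        config["encoder_subnet"] = keras.saving.get_registered_object(name) or name
        return cls(**config)
```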
else:
    inp = concatenate(x, conditions, axis=-1)
network_out = self.encoder_projector(
    self.encoder_subnet(inp, training=training, **kwargs), training=training, **kwargs
I would prefer this non-nested for better errors, but again, not a major issue.
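For instance, the call above could be split so that a traceback points at the specific sub-layer that fails (a sketch of the suggestion, relying on the surrounding method context shown above):

```python
subnet_out = self.encoder_subnet(inp, training=training, **kwargs)
network_out = self.encoder_projector(subnet_out, training=training, **kwargs)
```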
x_out, jac = compute_jacobian(x_in, fn, *func_args, grad_type=grad_type, **func_kwargs)
jac = ops.reshape(
    jac, (ops.shape(x_in)[0], ops.prod(list(ops.shape(x_out)[1:])), ops.prod(list(ops.shape(x_in)[1:])))
I would prefer this in multiple lines.
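For example, the reshape above could be spread over several lines like this (same `ops` calls, just reformatted):

```python
batch_size = ops.shape(x_in)[0]
rows = ops.prod(list(ops.shape(x_out)[1:]))
cols = ops.prod(list(ops.shape(x_in)[1:]))
jac = ops.reshape(jac, (batch_size, rows, cols))
```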
WIP; relates to #118.
Free-form flows can be an interesting model class for SBI. In my first trials (without much tuning), I could not match the performance of our existing inference networks. Nevertheless, a readily available implementation enables exploration and would therefore be worthwhile to have.