-
Notifications
You must be signed in to change notification settings - Fork 706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PFNN class to jax backend #1671
Conversation
I refactored to allow multidimensional output for subnetworks before concatenation : |
Refactor the code using https://github.com/psf/black |
Just to confirm, this code supports the same function as tf code, but with additional functions, right? |
Do you mean the same features compared to tf ? Yes, it can be called the same way as tf backend and will give the same architecture (consistency with other backends). It has some extra features (allow shared layer after splitting and multi-dimensionnal subnetworks output). I can easily add those features in other backends if you think it's valuable |
Sounds great.
Yes, feel free to add those features after this PR is merged. |
There was a problem with output/input transform in jax. I also added hard BC support for elasticity plate example |
Good catch. Could you make this a separate PR? |
This reverts commit c8f0bbb.
Done, I'll make a new PR once this one is merged. |
I also fixed an error introduced in a previous commit (I was applying the activation function on the last layer) |
6224c69
to
5fb3064
Compare
Implementation of PFNN in Jax backend.
There is a slight difference with other backends to allow "rejoin parallel subnetworks after splitting":
A unique shared layer can be applied to each subnetwork after splitting (example : layers = [2, [40] * 5, 40, [40] * 5, 5]
That's why I changed the last layer definition :
compare to tf:
because currently in tf and pytorch backend, the only possible way of having layer_sizes[-2] not a list is to have no list at all in layer_sizes ("cannot rejoin parallel subnetworks after splitting")
The implementation is tested in the "elasticy_plate.py" example