Add PFNN class to jax backend #1671

Merged (19 commits) on Apr 8, 2024
Conversation

@bonneted (Contributor) commented on Mar 6, 2024:

Implementation of PFNN in the Jax backend.
There is a slight difference from the other backends, to allow "rejoining parallel subnetworks after splitting":
a single shared layer can be applied to each subnetwork after splitting (example: layers = [2, [40] * 5, 40, [40] * 5, 5]).
That's why I changed the last layer definition:

# jax
if any(isinstance(unit, (list, tuple)) for unit in self.layer_sizes):
    # at least one split in layer_sizes: one dense(1) head per output component
    denses.append([make_dense(1)] * n_output)
else:
    # no split anywhere: a single dense output layer with n_output units
    denses.append(make_dense(n_output))

Compared with the tf implementation:

# tensorflow
if isinstance(layer_sizes[-2], (list, tuple)):  # e.g. [3, 3, 3] -> 3
    self.denses.append(
        [
            tf.keras.layers.Dense(
                1,
                kernel_initializer=initializer,
                kernel_regularizer=self.regularizer,
            )
            for _ in range(n_output)
        ]
    )
else:
    self.denses.append(
        tf.keras.layers.Dense(
            n_output,
            kernel_initializer=initializer,
            kernel_regularizer=self.regularizer,
        )
    )

This matters because, in the current tf and pytorch backends, the only way for layer_sizes[-2] not to be a list is for layer_sizes to contain no lists at all ("cannot rejoin parallel subnetworks after splitting").

The implementation is tested with the "elasticity_plate.py" example.
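To make the proposed layer_sizes semantics concrete, here is a minimal, standalone JAX sketch (not the actual deepxde/flax implementation; init_dense, dense, init_pfnn, and forward_pfnn are illustrative names): a list entry creates one private layer per subnetwork, a plain integer after a split is a single shared layer applied to every branch, and, for simplicity, the output layer is written here as a per-branch list of widths.

import jax
import jax.numpy as jnp

def init_dense(key, n_in, n_out):
    # Glorot-normal-style weights, zero bias.
    scale = jnp.sqrt(2.0 / (n_in + n_out))
    return {"W": scale * jax.random.normal(key, (n_in, n_out)), "b": jnp.zeros(n_out)}

def dense(p, x):
    return x @ p["W"] + p["b"]

def init_pfnn(key, layer_sizes):
    params, prev = [], layer_sizes[0]
    for sizes in layer_sizes[1:]:
        key, sub = jax.random.split(key)
        if isinstance(sizes, (list, tuple)):  # one private layer per branch
            keys = jax.random.split(sub, len(sizes))
            prev_list = prev if isinstance(prev, (list, tuple)) else [prev] * len(sizes)
            params.append([init_dense(k, p, s) for k, p, s in zip(keys, prev_list, sizes)])
        else:  # one shared layer (branch widths are assumed equal here)
            prev_in = prev[0] if isinstance(prev, (list, tuple)) else prev
            params.append(init_dense(sub, prev_in, sizes))
        prev = sizes
    return params

def forward_pfnn(params, layer_sizes, x, activation=jnp.tanh):
    # x: (batch, layer_sizes[0]) -> (batch, sum of output widths)
    n_layers = len(layer_sizes) - 1
    branches = None  # None until the first split
    for i, (sizes, p) in enumerate(zip(layer_sizes[1:], params)):
        act = (lambda z: z) if i == n_layers - 1 else activation  # no activation on the output layer
        if isinstance(sizes, (list, tuple)):
            if branches is None:  # first split: every branch reads the shared input
                branches = [act(dense(pi, x)) for pi in p]
            else:
                branches = [act(dense(pi, b)) for pi, b in zip(p, branches)]
        elif branches is None:  # ordinary fully connected layer
            x = act(dense(p, x))
        else:  # shared layer applied to each branch after splitting
            branches = [act(dense(p, b)) for b in branches]
    return x if branches is None else jnp.concatenate(branches, axis=-1)

layer_sizes = [2, [40] * 5, 40, [40] * 5, [1] * 5]  # shared 40-unit layer between the two splits
params = init_pfnn(jax.random.PRNGKey(0), layer_sizes)
print(forward_pfnn(params, layer_sizes, jnp.ones((8, 2))).shape)  # (8, 5)

With the final entry written as [1] * 5, each subnetwork contributes a single output, matching the denses.append([make_dense(1)] * n_output) branch quoted above; the PR additionally accepts a plain integer there with the same meaning.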

@bonneted (Contributor, Author) commented on Mar 7, 2024:

I refactored to allow multidimensional outputs for the subnetworks before concatenation:
layer_sizes = [2, [40] * 3, [40] * 3, [1, 2, 3]] will create a PFNN with output dimension 6 = 1 + 2 + 3.
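As a quick, self-contained check of that output-dimension claim (plain jax.numpy on dummy per-branch outputs, not the PR code itself): concatenating branch outputs of widths 1, 2, and 3 gives a 6-dimensional output.

import jax.numpy as jnp

# Dummy per-branch outputs for a batch of 8 points, with widths 1, 2 and 3.
branch_outputs = [jnp.zeros((8, d)) for d in (1, 2, 3)]
print(jnp.concatenate(branch_outputs, axis=-1).shape)  # (8, 6)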

@lululxvi (Owner) commented on Mar 7, 2024:

Refactor the code using https://github.com/psf/black

@lululxvi (Owner):

Just to confirm, this code supports the same functionality as the tf code, but with additional features, right?

@bonneted (Contributor, Author):

Do you mean the same features as the tf backend? Yes, it can be called the same way as the tf backend and will give the same architecture (consistent with the other backends). It has some extra features (a shared layer after splitting and multidimensional subnetwork outputs). I can easily add those features to the other backends if you think it's valuable.

@lululxvi (Owner):

Do you mean the same features as the tf backend? Yes, it can be called the same way as the tf backend and will give the same architecture (consistent with the other backends).

Sounds great.

It has some extra features (a shared layer after splitting and multidimensional subnetwork outputs). I can easily add those features to the other backends if you think it's valuable.

Yes, feel free to add those features after this PR is merged.

@bonneted (Contributor, Author):

There was a problem with the input/output transforms in jax.
Because jax computes gradients in a pointwise fashion, the input passed through the network for the gradient calculation has shape (n_input,), which raised an error when accessing x[:, i] or f[:, i] in the input/output transform.
Adding reshape(1, -1) before and squeeze() after solved it.

I also added hard BC support for the elasticity plate example.
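A minimal standalone sketch of the workaround described above (not the actual deepxde code; output_transform and apply_transform_pointwise_safe are illustrative names): when the network is evaluated point-wise for gradients the input is 1-D, so the point is promoted to a batch of one before the user transform and squeezed back afterwards, which keeps transforms written in the usual batched style working.

import jax.numpy as jnp

def output_transform(x, f):
    # Example transform written in the usual batched style (hypothetical hard BC).
    return x[:, 0:1] * f

def apply_transform_pointwise_safe(x, f, transform):
    if x.ndim == 1:  # point-wise call coming from the gradient code
        x_b = x.reshape(1, -1)  # promote the single point to a batch of one
        f_b = f.reshape(1, -1)
        return transform(x_b, f_b).squeeze(axis=0)  # back to a 1-D output
    return transform(x, f)  # already batched: apply directly

x_point = jnp.array([0.5, 1.0])  # one input point, shape (2,)
f_point = jnp.array([3.0])       # network output at that point, shape (1,)
print(apply_transform_pointwise_safe(x_point, f_point, output_transform))  # [1.5]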

@lululxvi (Owner):

There was a problem with the input/output transforms in jax. Because jax computes gradients in a pointwise fashion, the input passed through the network for the gradient calculation has shape (n_input,), which raised an error when accessing x[:, i] or f[:, i] in the input/output transform. Adding reshape(1, -1) before and squeeze() after solved it.

I also added hard BC support for the elasticity plate example.

Good catch. Could you make this a separate PR?

@bonneted (Contributor, Author):

Done, I'll make a new PR once this one is merged.

Review threads on deepxde/nn/jax/fnn.py (outdated, resolved)
@bonneted (Contributor, Author):

I also fixed an error introduced in a previous commit (I was applying the activation function to the last layer).

Review threads on deepxde/nn/jax/fnn.py (outdated, resolved)
@bonneted force-pushed the PFNN-jax branch 2 times, most recently from 6224c69 to 5fb3064, on April 6, 2024 09:27
Review thread on deepxde/nn/jax/fnn.py (outdated, resolved)
@lululxvi merged commit ad369e1 into lululxvi:master on Apr 8, 2024 (11 checks passed)
@bonneted mentioned this pull request on Apr 9, 2024