Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Jastrow kernel symmetric #644

Merged
merged 3 commits into from
Apr 26, 2021
Merged

Make Jastrow kernel symmetric #644

merged 3 commits into from
Apr 26, 2021

Conversation

PhilipVinc
Copy link
Member

No description provided.

@PhilipVinc PhilipVinc requested a review from gcarleo April 22, 2021 14:47
@gcarleo
Copy link
Member

gcarleo commented Apr 22, 2021

@femtobit or @wdphy16 if you know a better way to only use half of the parameters and work with symmetric matrices it would be amazing

@github-actions
Copy link

Hello and thanks for your Contribution!
I will be building previews of the updated documentation at the following link:
https://netket.github.io/netket/preview/PhilipVinc-patch-1

Once the PR is closed or merged, the preview will be automatically deleted.

@femtobit
Copy link
Collaborator

Should W be symmetric or Hermitian for complex parameters?

@gcarleo
Copy link
Member

gcarleo commented Apr 22, 2021

just symmetric

@codecov-commenter
Copy link

Codecov Report

Merging #644 (9c8be29) into master (8d9909e) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #644   +/-   ##
=======================================
  Coverage   67.56%   67.56%           
=======================================
  Files         192      192           
  Lines       10907    10908    +1     
  Branches     1555     1555           
=======================================
+ Hits         7369     7370    +1     
  Misses       3124     3124           
  Partials      414      414           
Impacted Files Coverage Δ
netket/models/jastrow.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8d9909e...9c8be29. Read the comment docs.

@femtobit
Copy link
Collaborator

femtobit commented Apr 23, 2021

@gcarleo It is possible to only store the upper triangular part of W and build the full matrix in __call__. I've tried this here:

class Jastrow(nn.Module):
"""
Jastrow wave function :math:`\Psi(s) = \exp(\sum_{ij} s_i W_{ij} s_j)`.
Note that :math:`W` is a symmetric matrix in this ansatz.
The module parameter :code:`kernel_triu` contains its :math:`N(N+1)/2` independent
entries on the upper triangular part.
"""
input_size: int
"""Size of the input configurations."""
dtype: DType = jnp.complex128
"""The dtype of the weights."""
kernel_init: NNInitFunc = normal()
"""Initializer for the weights."""
def setup(self):
self.n_par = self.input_size * (self.input_size + 1) // 2
self.triu_indices = jnp.triu_indices(self.input_size)
@nn.compact
def __call__(self, x_in: Array):
nv = x_in.shape[-1]
dtype = jnp.promote_types(x_in.dtype, self.dtype)
x_in = jnp.asarray(x_in, dtype=dtype)
params = self.param("kernel_triu", self.kernel_init, (self.n_par,), self.dtype)
kernel = jnp.empty((nv, nv), dtype=self.dtype)
kernel = jax.ops.index_update(
kernel,
self.triu_indices,
params,
)
kernel = kernel + jnp.tril(kernel.T, 1)
y = jnp.einsum("...i,ij,...j", x_in, kernel, x_in)
return y

There is a performance overhead, though it doesn't look too bad (at least for N=32 spins):

# Testing (JITed) Jastrow.apply:
# W + W.T symmetrization
14 µs ± 657 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# triu
18.8 µs ± 695 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

@gcarleo
Copy link
Member

gcarleo commented Apr 23, 2021

nice @femtobit, it's true that there is some small overhead but maybe it is still better when computing gradients, since we only have half the parameters... maybe we can implement your solution instead of the transpose trick?

@PhilipVinc
Copy link
Member Author

Is there a pathological case we could check for which one works best?

@gcarleo
Copy link
Member

gcarleo commented Apr 26, 2021

let's merge this for now and we'll get back to this once more urgent things have been addressed

@gcarleo gcarleo merged commit 2ec82da into master Apr 26, 2021
@gcarleo gcarleo deleted the PhilipVinc-patch-1 branch April 26, 2021 08:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants