Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MCState.local_estimators #1179

Merged
merged 4 commits into from
May 4, 2022
Merged

Add MCState.local_estimators #1179

merged 4 commits into from
May 4, 2022

Conversation

femtobit
Copy link
Collaborator

This PR adds an MCState.local_estimators method, addressing point 2 of #1178 (and superseding the previous PR #1154).

O_loc = vstate.local_estimators(operator)
print(O_loc.shape) # -> (n_chains, n_samples)

It only supports computing the estimate for the current vstate.samples. That is a restriction because of the way the interface described in https://netket.readthedocs.io/en/latest/advanced/custom_operators.html allows extending operators in a way that depends on the VariationalState, making it non-trivial to decouple that logic. Doing it this way should ensure that the same code path is used to compute the local estimators that is also used for expect.

@codecov
Copy link

codecov bot commented Apr 28, 2022

Codecov Report

Merging #1179 (fe552a9) into master (6f99f28) will decrease coverage by 0.05%.
The diff coverage is 88.88%.

@@            Coverage Diff             @@
##           master    #1179      +/-   ##
==========================================
- Coverage   82.26%   82.21%   -0.06%     
==========================================
  Files         207      207              
  Lines       12512    12519       +7     
  Branches     1902     1907       +5     
==========================================
- Hits        10293    10292       -1     
- Misses       1776     1779       +3     
- Partials      443      448       +5     
Impacted Files Coverage Δ
netket/vqs/mc/mc_state/state.py 95.65% <88.88%> (-0.52%) ⬇️
netket/operator/_hamiltonian.py 57.05% <0.00%> (-3.33%) ⬇️
netket/experimental/dynamics/_rk_tableau.py 92.50% <0.00%> (ø)
netket/experimental/driver/tdvp.py 87.44% <0.00%> (+0.05%) ⬆️
netket/operator/_local_liouvillian.py 61.53% <0.00%> (+0.21%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 6f99f28...fe552a9. Read the comment docs.

Copy link
Collaborator

@attila-i-szabo attila-i-szabo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as the operators we define go, this looks good to me.

As for user-defined operators, I'm not sure about restricting ourselves to vstate.samples. I mean, it makes sense in context, but for most common kernels, nothing depends on s==vstate.samples. Also, relaxing this would allow implementing a local_estimators for VariationalState largely the same way (probably with some more elaborate dispatch rules). That would be nice, because you can reasonably ask for local estimators on an ExactState (it still has a model with parameters, so there's no lower level to do that), but there isn't a default set of samples for that.

I guess the minimal change would be adding an optional σ argument to get_local_kernel_arguments, but this might be too breaking. Other ideas?

netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
Co-authored-by: Attila Szabó <33730178+attila-i-szabo@users.noreply.github.com>
@femtobit
Copy link
Collaborator Author

I guess the minimal change would be adding an optional σ argument to get_local_kernel_arguments, but this might be too breaking. Other ideas?

@attila-i-szabo I agree this would be nice and I think it could potentially be solved with dispatch in a backwards-compatible way: We extend the get_local_kernel_args implementations provided in NetKet to accept an (optional?) samples argument (I think all of them just call s = vstate.samples in their first line anyways) and then provide dispatch rules that still call old implementations which do not accept samples in the case where this is not required.

@attila-i-szabo
Copy link
Collaborator

@PhilipVinc would something like

@dispatch
def get_local_kernel_arguments(vstate: VariationalState, : DiscreteOperator, σ: Optional[Array] = None):

be found by dispatch if you call get_local_kernel_arguments(vstate, O), or do you need two distinct versions?

@PhilipVinc
Copy link
Member

Yes. When you have a default value it gets converted into two functions:

def get_local_kernel_arguments(vstate: VariationalState, Ô: DiscreteOperator, σ: Optional[Array] = None):
    blabla

becomes

def get_local_kernel_arguments(vstate: VariationalState, Ô: DiscreteOperator):
   return get_local_kernel_arguments(vstate, Ô, None)
def get_local_kernel_arguments(vstate: VariationalState, Ô: DiscreteOperator, σ: Optional[Array]):
   blabla

@PhilipVinc
Copy link
Member

I was going through this PR.
So I have the following comments:

As for user-defined operators, I'm not sure about restricting ourselves to vstate.samples.

I think we can restrict ourselves to the purposes of the vstate itself, but it would be useful to start building the primitives own top of which vstate is built in order to support more arbitrary shenanigans like changing the samples.

I mean, it makes sense in context, but for most common kernels, nothing depends on s==vstate.samples.

To give some context, the reason I put the logic taking the samples there is that for the specific combinations of mixed states + Standard Operators the samples must be manipulated before passing them to the kernel.

Also, relaxing this would allow implementing a local_estimators for VariationalState largely the same way (probably with some more elaborate dispatch rules).

I agreee.

I guess the minimal change would be adding an optional σ argument to get_local_kernel_arguments, but this might be too breaking. Other ideas?

The third argument is already taken to be an (optional) int that signals the use of chunking.
You can add another argument that are the samples (which would have type :jnp.ndarray, I think) but you must be careful to add also the combinations of samples + chunking.

--

In general, I'd like the guiding design principle of the internals to be done so that those things could be eventually used in conjunction with jax.vjp by themselves.

@gcarleo
Copy link
Member

gcarleo commented May 2, 2022

my suggestion would be to limit this to vstate.samples for now, as implemented by @femtobit (and limit it to McState type of variational state, since the concept itself of local_estimator is not applicable to exact states)

@femtobit
Copy link
Collaborator Author

femtobit commented May 2, 2022

I agree with @gcarleo.

The extensions we are discussing here can still be done after this PR is merged, as the only user-facing change at that point should be an added optional argument samples to MCState.local_estimators.

In the meantime, we can already use the local estimators for vstate.samples which is the currently most relevant use of this new function.

Any comments on this specific implementation? Or can it be megred? @PhilipVinc

Copy link
Member

@PhilipVinc PhilipVinc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would (for now) not support chunk_size here.
The method is missing a jitting.

netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
netket/vqs/mc/mc_state/state.py Outdated Show resolved Hide resolved
Copy link
Member

@PhilipVinc PhilipVinc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good assuming you un-nest the if statements

@femtobit femtobit merged commit 2a7dded into master May 4, 2022
@femtobit femtobit deleted the local_estimators branch May 4, 2022 20:53
PhilipVinc added a commit that referenced this pull request May 24, 2022
nikosavola pushed a commit to nikosavola/netket that referenced this pull request Jun 4, 2022
* Add MCState.local_estimators

* Pass arbitrary `extra_args`

Co-authored-by: Attila Szabó <33730178+attila-i-szabo@users.noreply.github.com>

* JIT-compile local_estimators

* Implement suggestions from code review
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants