Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Squared Operator #1127

Merged
merged 7 commits into from
Mar 1, 2022
Merged

Fix Squared Operator #1127

merged 7 commits into from
Mar 1, 2022

Conversation

PhilipVinc
Copy link
Member

Cherry picked from #1065 so that i separate the two changes in two different PRs.

Computing the gradient of operators that use nkjax.expect instead of the covariance formula (such as SquaredOperator) also had a wrong factor of 2 for C->C models.
This is due to the weird way that Jax handles complex differentiation.

This PR now fixes it, and a test is added to check the gradient wrt finite differences.

Also, this PR moves out into a common test file the finite difference functions so that they can be used for other tests.

@codecov
Copy link

codecov bot commented Feb 28, 2022

Codecov Report

Merging #1127 (3f9ea41) into master (45a7e06) will increase coverage by 0.25%.
The diff coverage is 100.00%.

❗ Current head 3f9ea41 differs from pull request most recent head e6fd7e1. Consider uploading reports for the commit e6fd7e1 to get more accurate results

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1127      +/-   ##
==========================================
+ Coverage   81.66%   81.91%   +0.25%     
==========================================
  Files         207      207              
  Lines       12543    12544       +1     
  Branches     1902     1903       +1     
==========================================
+ Hits        10243    10276      +33     
+ Misses       1854     1825      -29     
+ Partials      446      443       -3     
Impacted Files Coverage Δ
netket/vqs/mc/mc_state/expect_grad.py 84.12% <100.00%> (+0.25%) ⬆️
netket/vqs/mc/mc_state/expect_chunked.py 83.67% <0.00%> (+2.04%) ⬆️
netket/jax/_vjp.py 97.29% <0.00%> (+2.70%) ⬆️
netket/vqs/mc/mc_state/expect.py 96.29% <0.00%> (+9.25%) ⬆️
netket/tools/_cpu_info.py 63.26% <0.00%> (+24.48%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 45a7e06...e6fd7e1. Read the comment docs.

@gcarleo
Copy link
Member

gcarleo commented Feb 28, 2022

cool! was this affecting steady state solvers?

CHANGELOG.md Outdated Show resolved Hide resolved
@PhilipVinc
Copy link
Member Author

Not really, it's just a factor of 2, but I'm trying to get those factors right.

@PhilipVinc
Copy link
Member Author

The steady state does not go through this code path (yet. I'll open another PR to do that, as this code path is simpler, cleaner and faster)

@PhilipVinc PhilipVinc merged commit 52c23ce into master Mar 1, 2022
@PhilipVinc PhilipVinc deleted the pv/sq2 branch March 1, 2022 00:03
nikosavola pushed a commit to nikosavola/netket that referenced this pull request Jun 4, 2022
Cherry picked from netket#1065 so that i separate the two changes in two different PRs.

Computing the gradient of operators that use nkjax.expect instead of the covariance formula (such as SquaredOperator) also had a wrong factor of 2 for C->C models.
This is due to the weird way that Jax handles complex differentiation.

This PR now fixes it, and a test is added to check the gradient wrt finite differences.

Also, this PR moves out into a common test file the finite difference functions so that they can be used for other tests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants