Basic iterator interface for optimization drivers #131
Conversation
Thanks a lot @femtobit, this is excellent work, and the interface is very cool! Just let me add a few comments here, before going into the review:
Just to clarify: if I do something like

```python
for st in vmc.iter(10, 2):
    obs = st.observables
    ...
```

are the observables computed at every step or only every 2 steps? (The energy anyway needs to be computed at every step, but maybe we can trigger the calculation of other observables, as added with

On a related note, I think we could just remove

I understand that this would break the possibility of retrieving the parameters/acceptance doing

I think that, if possible, calling
@femtobit my main comments are about naming, which basically means that the PR is great :)
@@ -31,33 +31,77 @@ namespace netket {

```cpp
void AddGroundStateModule(py::module &m) {
  auto subm = m.def_submodule("gs");
```
I think that `netket.gs.vmc.Vmc` is too deep of a module hierarchy; we should try to simplify it. `netket.Vmc` doesn't look bad to me, but let's discuss more.
I think it would make sense to separate the VMC code from the ED code. So `netket.vmc` and `netket.ed` seem fine to me.
Although strictly speaking `ed` is a bit of a misnomer for exact time evolution.
We can do `netket.exact.TimeEvolution`, `netket.exact.SparseDiag`, `netket.exact.FullDiag`?

It's just that `netket.vmc.Vmc` seems redundant to me, since the `vmc` module would consist only of the `Vmc` class, thus I'd rather do `netket.Vmc`, much simpler.
Yes, `netket.exact` sounds good.

Regarding `netket.vmc`: I see your point; however, we might get `netket.TdVmc` in the future. (In principle, we could also call it `netket.vmc.TimeEvolution` or similar.) Also, there is `netket.vmc.Iterator` (which users would usually not see, but has to be somewhere).
```cpp
py::arg("use_iterative") = false, py::arg("use_cholesky") = true,
py::arg("save_every") = 50)
py::arg("use_iterative") = false, py::arg("use_cholesky") = true)
    .def_property_readonly("psi", &VariationalMonteCarlo::GetPsi)
```
For consistency we should name this `machine` and not `psi` (my fault, I know).
```cpp
return py::make_iterator(self.begin(), self.end());
});

auto excact = subm.def_submodule("exact");
```
Same here, we could do `netket.ed.ImagTimePropagation` and `netket.ed.SparseDiag` etc.?
@@ -57,8 +57,8 @@ std::vector<netket::json> GetExactDiagonalizationInputs() {

```cpp
std::vector<std::vector<double>> szsz = {
```
I think this file can be just removed
```cpp
py::arg("hamiltonian"), py::arg("stepper"), py::arg("output_writer"),
py::arg("tmin"), py::arg("tmax"), py::arg("dt"))
.def("add_observable", &ImaginaryTimeDriver::AddObservable,
.def("run", &VariationalMonteCarlo::Run, py::arg("filename_prefix"),
```
Maybe `filename_prefix` should be `output_prefix` or just `output`, as before?
```python
mpi_rank = comm.Get_rank()

for i, st in enumerate(vmc.iter()):
    obs = dict(st.observables)  # TODO: needs to be called on all MPI processes
```
Do we need to explicitly convert to `dict` here?
Currently yes, because while the `ObsManager` mostly works like a `dict`, printing it does not give you all the data.
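To illustrate the point, here is a toy stand-in for `ObsManager` (the class name and data here are made up for this sketch): a read-only `Mapping` that supports item access but whose default `repr` does not show its contents, so an explicit `dict(...)` conversion is needed to see the data when printing.

```python
from collections.abc import Mapping


class ToyObsManager(Mapping):
    """Toy stand-in for ObsManager: dict-like access, uninformative repr."""

    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    # No custom __repr__: printing shows only the object identity,
    # e.g. "<ToyObsManager object at 0x...>", not the stored values.


obs = ToyObsManager({"Energy": -1.274, "SzSz": 0.31})
print(obs)        # opaque default repr, no data shown
print(dict(obs))  # readable: all key/value pairs
```

Since `Mapping` supplies `keys()` and friends from `__getitem__`/`__iter__`/`__len__`, the `dict(obs)` conversion works, which matches the "similar enough to a `dict`" behaviour described in this PR.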
I think that what could be useful is to do

```python
vmc.add_observable(ma.parameters)
vmc.add_observable(sa.acceptance)
```

so that one can choose what optional things to measure. One could even do `vmc.add_observable(datetime.datetime.now)`, which would print the time at which the iteration has been completed, really cool! This should be technically feasible by overloading
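A sketch of how such an overload might behave, using a toy driver (all names here are hypothetical, not the actual NetKet API): `add_observable` registers a plain callable that is evaluated at each measurement.

```python
import datetime


class ToyDriver:
    """Toy sketch: add_observable registers named callables that are
    evaluated once per measurement. A real overload would also accept
    operator-valued observables, dispatched on the argument type."""

    def __init__(self):
        self._observables = {}

    def add_observable(self, func, name=None):
        # fall back to the callable's __name__ if no name is given
        self._observables[name or getattr(func, "__name__", "obs")] = func

    def measure(self):
        # evaluate every registered observable for the current step
        return {name: func() for name, func in self._observables.items()}


vmc = ToyDriver()
vmc.add_observable(datetime.datetime.now, name="completed_at")
step_data = vmc.measure()
print(step_data["completed_at"])  # wall-clock time of this "iteration"
```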
Yes, I think this is a valid option.
Looks fun, but I'd prefer to keep the observables restricted to actual quantum mechanical observables, i.e., operators, and do

```python
for step in vmc.iter():
    date = datetime.datetime.now()
    # ... etc ...
```

What bothers me a bit is that in principle, all the information can be obtained from elsewhere, so there is no real need for the
Good point. I think they are calculated at every step right now. I will change that to compute them only when necessary.
Indeed that's true, that would better be some
True, (to retrieve observables one could also just do

But I agree that the iterator is convenient, yes, but maybe just a
Of course, it is essentially equivalent to something like

```python
for i in range(1000):
    vmc.advance(step_size)
    print(vmc.observables)
```

Maybe we should have implemented the iterators in pure Python and just provide the

```python
def iter_obs(vmc):
    while True:
        vmc.advance()
        yield vmc.observables()

for obs in iter_obs(vmc):
    print(obs)
```

But well, in the end the interface is similar, and providing control over the simulation to the surrounding code (which was the main goal) can be achieved either way. For now, I suggest we keep
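For concreteness, the pure-Python generator variant can be run against a mock driver (the mock and its fake energy values are invented for this sketch; the real driver would perform actual optimization in `advance()`):

```python
class MockVmc:
    """Mock stand-in for the C++ driver, exposing only advance()/observables()."""

    def __init__(self):
        self.step = 0

    def advance(self, steps=1):
        # a real driver would perform `steps` optimization sweeps here
        self.step += steps

    def observables(self):
        return {"step": self.step, "Energy": -1.0 / (1 + self.step)}  # fake data


def iter_obs(vmc, n_iter):
    # the generator from the comment above, bounded so the loop terminates
    for _ in range(n_iter):
        vmc.advance()
        yield vmc.observables()


for obs in iter_obs(MockVmc(), 3):
    print(obs["step"], obs["Energy"])
```

Either way, control returns to the surrounding Python code between steps, which is the property the interface is after.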
Thanks @femtobit, just a reminder to discuss later about adding the observables as a member of `vmc`, and leaving the `vmc` iterator just as an index counter to advance the iterations.
This is done now: netket/Tutorials/PyNetKet/ground_state_iter.py, lines 40 to 42 in 3ddd9a2.

The observable stats are only computed if the method is called. I chose this slightly more verbose function name (instead of, say, a property `vmc.observables`) because this call actually performs computation and the MPI reduction.
This change results in the MPI reduction being performed outside the for-loop body.
The iterator now only dereferences to the step index. All other information can now be obtained from the driver or other classes inside the iteration loop.
Otherwise, running the tests is significantly slower.
Merged into the feature branch, so that this code can serve as a base for other PRs.
Basic iterator interface for optimization drivers
This PR contains a first version of the iterator interface for the VMC (and exact imaginary time) driver, addressing #107.
(The implementation is partially modeled after the description here.)
Basically, it allows one to write code such as
In principle, it is possible to modify the state, Hamiltonian, or other objects in between iteration steps (though this is not done anywhere right now).
The Python bindings expose this function as well, allowing, e.g.,
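A minimal sketch of what such usage looks like, run against a mock driver (the mock classes, step indices, and energy values are all invented for illustration; this is not the actual tutorial code):

```python
class MockStep:
    """Mock of the per-iteration state object yielded by the driver."""

    def __init__(self, index):
        self.index = index
        self.observables = {"Energy": -1.0 - 0.01 * index}  # fake data


class MockDriver:
    """Mock driver whose iter(n_iter, step_size) yields every step_size-th step."""

    def iter(self, n_iter, step_size=1):
        for i in range(0, n_iter, step_size):
            yield MockStep(i)


vmc = MockDriver()
for step in vmc.iter(30, 10):
    obs = dict(step.observables)
    print(step.index, obs["Energy"])
```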
In order to expose `step.observables`, Python bindings for the `ObsManager` class are provided, with an interface similar to a Python `dict` (similar enough to make `d = dict(step.observables)` work).

Some open issues (or potentially points to discuss):
- The iterator dereferences to a `struct` containing basic information on the state: netket/NetKet/GroundState/variational_montecarlo.hpp, lines 98 to 103 in a5c6d62. This contains a copy of the data (as opposed to referencing the corresponding internal state of the `VariationalMonteCarlo` class) in order to make it possible to write code such as `steps = list(vmc.iter(100))`. This has the downside of copying this data every step, regardless of whether it is needed. For the machine parameters, an option is provided, so that `step.parameters` can be left empty (== `nonstd::nullopt`) if storage of the machine parameters is not needed.
- `ObsManager` still performs an MPI reduction which has to be performed on all MPI processes, leading to code such as netket/Tutorials/PyNetKet/ground_state_iter.py, lines 40 to 45 in 00a7c60. The code should probably be changed so that `step.observables` does not perform MPI calls. The question is whether it should contain sensible values only on rank 0 or on all MPI processes (the first option should be sufficient but might also lead to annoying-to-debug errors).
- A `VariationalMonteCarlo::Run` function is provided which is now implemented in terms of the iterator interface. This can be used to get essentially the behaviour of NetKet v1: netket/Tutorials/PyNetKet/ground_state.py, lines 42 to 50 in 00a7c60.
- This PR also adds the type aliases `Complex = std::complex<double>` (because I think it is cumbersome to always spell out that type) and `Index`. The `Index` type is added because right now there are several types used as indices (usually `int` or `std::size_t`) in different places in the netket codebase. I suggest standardizing on `Index = std::ptrdiff_t` (which is a 64-bit integer on my system). (Whether signed or unsigned types should be used as array indices is a somewhat controversial question in C++, but note that newer additions to the standard such as `std::span` use `ptrdiff_t`. This is also in line with what the Google Style Guide has to say about unsigned integers.)

Let me know what you think.
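The point about `Run` being a thin layer over the iterator interface can be sketched in Python (the mock driver and the JSON-lines output are assumptions of this sketch, not the actual NetKet implementation):

```python
import json


class MockVmc:
    """Mock driver illustrating a Run() convenience method layered on
    top of an iterator interface, as described for this PR."""

    def iter(self, n_iter):
        for i in range(n_iter):
            yield {"Iteration": i, "Energy": -1.0 - 0.01 * i}  # fake data

    def run(self, n_iter, writer=print):
        # NetKet-v1-style behaviour: loop internally and emit one record
        # per iteration, instead of handing control to the caller
        for step in self.iter(n_iter):
            writer(json.dumps(step))


MockVmc().run(3)
```

The design point is that `run` adds no new driver logic: it only consumes the same iterator that user code could loop over directly.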