Fix #1643 : support complex parameters in SRt (#1644)
fixes #1643

Apart from fixing some docstrings along the way to figuring out this issue,
the only real change is re-concatenating the parameters (and dividing by 2),
as sketched below.
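
A minimal sketch of that repacking (toy data; the helper name is ours, not the driver's): the flat solution of the SR linear system stacks derivatives with respect to the real parts of the parameters first and the imaginary parts second, and the complex updates are recovered like so:

    import jax.numpy as jnp

    def repack_complex_updates(updates):
        # First half: d/dRe[theta]; second half: d/dIm[theta].
        n = updates.shape[-1] // 2
        # The factor 1/2 compensates for the doubled real parametrisation.
        return (updates[:n] + 1j * updates[n:]) / 2

    u = jnp.array([1.0, 2.0, 3.0, 4.0])  # (Re-block, Im-block) of two parameters
    print(repack_complex_updates(u))     # [0.5+1.5j 1.+2.j]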
PhilipVinc committed Nov 8, 2023
1 parent 5dfea3d commit 257532d
Showing 5 changed files with 116 additions and 26 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -6,8 +6,12 @@
## NetKet 3.11 (⚙️ In development)


## NetKet 3.10.1 (8 November 2023)

### Bug Fixes
* Added support for neural networks with complex parameters to {class}`netket.experimental.driver.VMC_SRt`, which was just crashing with unreadable errors before [#1644](https://github.com/netket/netket/pull/1644).

## NetKet 3.10 (🥶 7 November 2023)

The highlights of this version are a new experimental driver to optimise networks with millions of parameters using SR, and new utility functions to convert a pyscf molecule to a netket Hamiltonian.

44 changes: 33 additions & 11 deletions netket/experimental/driver/vmc_srt.py
@@ -20,7 +20,9 @@


@partial(jax.jit, static_argnames=("mode", "solver_fn"))
def SRt(O_L, local_energies, diag_shift, *, mode, solver_fn, e_mean=None):
def SRt(
O_L, local_energies, diag_shift, *, mode, solver_fn, e_mean=None, params_structure
):
"""
For more details, see `https://arxiv.org/abs/2310.05715`. In particular,
the following parallel implementation is described in Appendix "Distributed SR computation".
@@ -41,7 +43,10 @@ def SRt(O_L, local_energies, diag_shift, *, mode, solver_fn, e_mean=None):
dv = -2.0 * de / N_mc**0.5

if mode == "complex":
O_L = jnp.concatenate((O_L[:, 0], O_L[:, 1]), axis=0)
# Concatenate the real and imaginary derivatives of the ansatz
# O_L = jnp.concatenate((O_L[:, 0], O_L[:, 1]), axis=0)
O_L = jnp.transpose(O_L, (1, 0, 2)).reshape(-1, O_L.shape[-1])

dv = jnp.concatenate((jnp.real(dv), -jnp.imag(dv)), axis=-1)
elif mode == "real":
dv = dv.real
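
# Illustrative sketch, not part of the diff: for a dense O_L of shape
# (samples, re/im, parameters), the transpose/reshape above is equivalent to
# the concatenation it replaces — both stack the real-part rows on top of
# the imaginary-part rows.
import jax.numpy as jnp

O_L_demo = jnp.arange(12.0).reshape(3, 2, 2)
stacked_cat = jnp.concatenate((O_L_demo[:, 0], O_L_demo[:, 1]), axis=0)
stacked_tr = jnp.transpose(O_L_demo, (1, 0, 2)).reshape(-1, O_L_demo.shape[-1])
assert jnp.array_equal(stacked_cat, stacked_tr)  # both are (6, 2)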
Expand Down Expand Up @@ -77,6 +82,12 @@ def SRt(O_L, local_energies, diag_shift, *, mode, solver_fn, e_mean=None):
updates = O_L.T @ aus_vector
updates, token = mpi.mpi_allreduce_sum_jax(updates, token=token)

# If we are in complex mode and have complex parameters, we need
# to repack the real coefficients in order to get complex updates
if mode == "complex" and nkjax.tree_leaf_iscomplex(params_structure):
np = updates.shape[-1] // 2
updates = (updates[:np] + 1j * updates[np:]) / 2

return -updates
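
# Illustrative sketch, not part of the diff: the repacked complex vector is
# what `_unravel_params_fn` later maps back onto the parameter pytree.
# `ravel_pytree` keeps complex leaves complex, so the round trip is lossless.
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp

demo_params = {"kernel": jnp.array([1.0 + 2.0j, 3.0 - 1.0j])}
flat, unravel = ravel_pytree(demo_params)
assert jnp.iscomplexobj(flat)  # length N, complex — not 2N reals
assert jnp.array_equal(unravel(flat)["kernel"], demo_params["kernel"])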


@@ -143,14 +154,6 @@ def __init__(
"""
super().__init__(hamiltonian, optimizer, variational_state=variational_state)

self._ham = hamiltonian.collect() # type: AbstractOperator
self.diag_shift = diag_shift
self.jacobian_mode = jacobian_mode
self._linear_solver_fn = linear_solver_fn

_, unravel_params_fn = ravel_pytree(self.state.parameters)
self._unravel_params_fn = jax.jit(unravel_params_fn)

if self.state.n_parameters % mpi.n_nodes != 0:
raise NotImplementedError(
f"""
@@ -173,6 +176,24 @@
stacklevel=2,
)

self._ham = hamiltonian.collect() # type: AbstractOperator
self.diag_shift = diag_shift
self.jacobian_mode = jacobian_mode
self._linear_solver_fn = linear_solver_fn

self._params_structure = jax.tree_map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), self.state.parameters
)
if not nkjax.tree_ishomogeneous(self._params_structure):
raise ValueError(
"SRt only supports neural networks with all real or all complex "
"parameters. Hybrid structures are not yet supported (but we would welcome "
"contributions. Get in touch with us!)"
)

_, unravel_params_fn = ravel_pytree(self.state.parameters)
self._unravel_params_fn = jax.jit(unravel_params_fn)
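
# Illustrative sketch, not part of the diff: `params_structure` is an
# abstract view of the parameters — shapes and dtypes only, no buffers —
# which is all the homogeneity and complex-leaf checks above need.
import jax
import jax.numpy as jnp
import netket.jax as nkjax

demo = {"w": jnp.zeros((3, 3), jnp.complex64), "b": jnp.zeros((3,), jnp.float32)}
structure = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), demo)
print(nkjax.tree_ishomogeneous(structure))   # False: mixed real/complex leaves
print(nkjax.tree_leaf_iscomplex(structure))  # True: some leaf is complex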

@property
def jacobian_mode(self) -> str:
"""
@@ -228,7 +249,7 @@ def _forward_and_backward(self):
mode=self.jacobian_mode,
dense=True,
center=True,
) # * jaxcobians is centered
) # jacobians is centered

diag_shift = self.diag_shift
if callable(self.diag_shift):
@@ -241,6 +262,7 @@
mode=self.jacobian_mode,
solver_fn=self._linear_solver_fn,
e_mean=self._loss_stats.Mean,
params_structure=self._params_structure,
)

self._dp = self._unravel_params_fn(updates)
4 changes: 2 additions & 2 deletions netket/jax/_jacobian/default_mode.py
@@ -56,7 +56,7 @@ def __eq__(self, o):
HolomorphicMode = JacobianMode("holomorphic")


@partial(jax.jit, static_argnames=("apply_fun", "holomorphic"))
@partial(jax.jit, static_argnames=("apply_fun", "holomorphic", "warn"))
def jacobian_default_mode(
apply_fun: Callable[[PyTree, Array], Array],
pars: PyTree,
@@ -67,7 +67,7 @@ def jacobian_default_mode(
warn: bool = True,
) -> JacobianMode:
"""
Returns the default `mode` for {func}`nk.jax.jacobian` given a certain
Returns the default `mode` for {func}`netket.jax.jacobian` given a certain
wave-function ansatz.
This function uses an abstract evaluation of the ansatz to determine if
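
The "abstract evaluation" mentioned above means tracing the ansatz with shape and dtype information only, so no real computation is performed. A minimal sketch of the idea (the helper name is ours, not netket's):

    import jax
    import jax.numpy as jnp

    def output_is_complex(apply_fun, pars, samples):
        # jax.eval_shape traces apply_fun without spending any FLOPs.
        out = jax.eval_shape(apply_fun, pars, samples)
        return jnp.issubdtype(out.dtype, jnp.complexfloating)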
45 changes: 35 additions & 10 deletions netket/jax/_jacobian/logic.py
@@ -103,7 +103,7 @@ def jacobian(
read the detailed discussion below.
pdf: Optional coefficient that is used to multiply every row of the Jacobian.
When performing calculations in full-summation, this can be used to
multiply every row by :math:`\abs{\psi(\sigma)}^2`, which is needed to
multiply every row by :math:`|\psi(\sigma)|^2`, which is needed to
compute the correct average.
chunk_size: Optional integer specifying the maximum number of samples for
which the gradient is simultaneously computed. Low values will
@@ -187,11 +187,12 @@ def jacobian(
.. math::
O_k(\sigma) = \frac{\partial \ln\Re[\Psi(\sigma)]}{\partial \Re[\theta_k]}
O^{r}_k(\sigma) = \frac{\partial \ln\Re[\Psi(\sigma)]}{\partial \theta_k}
\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,
O_k(\sigma) = \frac{\partial \ln\Im[\Psi(\sigma)]}{\partial \Re[\theta_k]}
O^{i}_k(\sigma) = \frac{\partial \ln\Im[\Psi(\sigma)]}{\partial \theta_k}
properly concatenated in a single PyTree for every set of parameters. In practice,
where :math:`O^{r}_k(\sigma)` and :math:`O^{i}_k(\sigma)` are real-valued pytrees
with the same shape as the original parameters. In practice,
it should return a result roughly equivalent to the following listing:
.. code:: python
@@ -207,17 +208,41 @@
do this for performance reasons, but the downstream user is free to do it if
they wish.
**If some parameters** :math:`\theta_k` **are complex**, this mode returns the
derivatives of the real and imaginary part of the function,
If you wish to get the complex jacobian in the case of real parameters, it is
possible to define
.. math::
O_k(\sigma) = \frac{\partial \ln\Re[\Psi(\sigma)]}{\partial \Re[\theta_k]}
O_k(\sigma) = O^{r}_k(\sigma) + i O^{i}_k(\sigma)
which is now complex-valued. In code, this is equivalent to
.. code:: python
O_k_cmplx = jax.tree_map(lambda jri: jri[:, 0, :] + 1j * jri[:, 1, :], O_k)
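
As a self-contained toy version of that recombination (the shapes are ours, for illustration only):

.. code:: python

    import jax
    import jax.numpy as jnp

    O_k = {"w": jnp.ones((4, 2, 3))}  # 4 samples, (re, im) axis, 3 parameters
    O_k_cmplx = jax.tree_map(lambda jri: jri[:, 0, :] + 1j * jri[:, 1, :], O_k)
    print(O_k_cmplx["w"].shape)  # (4, 3), complex-valued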
**If some parameters** :math:`\theta_k` **are complex**, this mode splits the
:math:`N` complex parameters into :math:`2N` real parameters, where the first
block of :math:`N` parameters corresponds to the real parts and the latter block
to the imaginary parts, and then follows the logic discussed above.
In formulas, this can be seen as defining the vector of :math:`2N` real parameters
.. math::
\tilde\theta = (\Re[\theta], \Im[\theta])
and then computing the same quantities as above
.. math::
O^{r}_k(\sigma) = \frac{\partial \ln\Re[\Psi(\sigma)]}{\partial \tilde\theta_k}
\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,\,
O_k(\sigma) = \frac{\partial \ln\Im[\Psi(\sigma)]}{\partial \Re[\theta_k]}
O^{i}_k(\sigma) = \frac{\partial \ln\Im[\Psi(\sigma)]}{\partial \tilde\theta_k}
properly concatenated in a single PyTree for every set of parameters. In practice,
it should return a result roughly equivalent to the following listing:
where now those objects have twice as many elements as the parameters.
In practice, it should return a result roughly equivalent to the following listing:
.. code:: python
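The listing itself is elided in this view; as a stand-in illustration of the :math:`2N`-real-parameter view described above (toy values, ours):

    import jax.numpy as jnp

    theta = jnp.array([1.0 + 2.0j, 3.0 + 4.0j])              # N = 2 complex parameters
    theta_tilde = jnp.concatenate((theta.real, theta.imag))  # 2N reals: [1, 3, 2, 4]

    n = theta_tilde.shape[0] // 2
    assert jnp.array_equal(theta_tilde[:n] + 1j * theta_tilde[n:], theta)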
43 changes: 41 additions & 2 deletions test/driver/test_srt.py
@@ -63,7 +63,7 @@ def __call__(self, x):
return jnp.sum(x, axis=-1).astype(jnp.complex128)


def _setup(complex=True):
def _setup(*, complex=True, machine=None):
L = 4
Ns = L * L
lattice = nk.graph.Square(L, max_neighbor_order=2)
@@ -73,7 +73,8 @@ def _setup(complex=True):
)

# Define a variational state
machine = RBM(num_hidden=2 * Ns, complex=complex)
if machine is None:
machine = RBM(num_hidden=2 * Ns, complex=complex)

sampler = nk.sampler.MetropolisExchange(
hilbert=hi,
@@ -94,6 +95,44 @@
return H, opt, vstate


def test_SRt_vs_linear_solver_complexpars():
"""
nk.driver.VMC_kernelSR must give **exactly** the same dynamics as nk.driver.VMC with nk.optimizer.SR
"""
n_iters = 5

model = nk.models.RBM(
param_dtype=jnp.complex128,
kernel_init=jax.nn.initializers.normal(stddev=0.02),
hidden_bias_init=jax.nn.initializers.normal(stddev=0.02),
use_visible_bias=False,
)

H, opt, vstate_srt = _setup(machine=model)
gs = VMC_SRt(
H, opt, variational_state=vstate_srt, diag_shift=0.1, jacobian_mode="complex"
)
logger_srt = nk.logging.RuntimeLog()
gs.run(n_iter=n_iters, out=logger_srt)

H, opt, vstate_sr = _setup(machine=model)
sr = nk.optimizer.SR(solver=solve, diag_shift=0.1, holomorphic=False)
gs = nk.driver.VMC(H, opt, variational_state=vstate_sr, preconditioner=sr)
logger_sr = nk.logging.RuntimeLog()
gs.run(n_iter=n_iters, out=logger_sr)

# check same parameters
jax.tree_map(
np.testing.assert_allclose, vstate_srt.parameters, vstate_sr.parameters
)

if mpi.rank == 0:
energy_kernelSR = logger_srt.data["Energy"]["Mean"]
energy_SR = logger_sr.data["Energy"]["Mean"]

np.testing.assert_allclose(energy_kernelSR, energy_SR, atol=1e-10)


def test_SRt_vs_linear_solver():
"""
nk.driver.VMC_kernelSR must give **exactly** the same dynamics as nk.driver.VMC with nk.optimizer.SR
