From b13b4b1ac8bcd1cccc9138f8f7bb9c1050d5cf93 Mon Sep 17 00:00:00 2001 From: chrisrothUT Date: Mon, 25 Sep 2023 19:09:14 +0000 Subject: [PATCH] Add grad chunk size --- .../vqs/mc/mc_state/expect_forces_chunked.py | 10 +++++++++- netket/vqs/mc/mc_state/expect_grad_chunked.py | 7 ++++++- netket/vqs/mc/mc_state/state.py | 20 +++++++++++++++++-- 3 files changed, 33 insertions(+), 4 deletions(-) diff --git a/netket/vqs/mc/mc_state/expect_forces_chunked.py b/netket/vqs/mc/mc_state/expect_forces_chunked.py index 7da98e9596..c2baafa268 100644 --- a/netket/vqs/mc/mc_state/expect_forces_chunked.py +++ b/netket/vqs/mc/mc_state/expect_forces_chunked.py @@ -42,6 +42,7 @@ def expect_and_forces_nochunking( # noqa: F811 vstate: MCState, operator: AbstractOperator, chunk_size: None, + grad_chunk_size: None, *args, **kwargs, ): @@ -54,6 +55,7 @@ def expect_and_forces_fallback( # noqa: F811 vstate: MCState, operator: AbstractOperator, chunk_size: Any, + grad_chunk_size: Any, *args, **kwargs, ): @@ -72,6 +74,7 @@ def expect_and_forces_impl( # noqa: F811 vstate: MCState, Ô: AbstractOperator, chunk_size: int, + grad_chunk_size: int, *, mutable: CollectionFilter, ) -> tuple[Stats, PyTree]: @@ -81,6 +84,7 @@ def expect_and_forces_impl( # noqa: F811 Ō, Ō_grad, new_model_state = forces_expect_hermitian_chunked( chunk_size, + grad_chunk_size, local_estimator_fun, vstate._apply_fun, mutable, @@ -99,6 +103,7 @@ def expect_and_forces_impl( # noqa: F811 @partial(jax.jit, static_argnums=(0, 1, 2, 3)) def forces_expect_hermitian_chunked( chunk_size: int, + grad_chunk_size: int, local_value_kernel_chunked: Callable, model_apply_fun: Callable, mutable: CollectionFilter, @@ -111,6 +116,9 @@ def forces_expect_hermitian_chunked( if jnp.ndim(σ) != 2: σ = σ.reshape((-1, σ_shape[-1])) + if grad_chunk_size is None: + grad_chunk_size = chunk_size + n_samples = σ.shape[0] * mpi.n_nodes O_loc = local_value_kernel_chunked( @@ -134,7 +142,7 @@ def forces_expect_hermitian_chunked( parameters, σ, conjugate=True, - chunk_size=chunk_size, + chunk_size=grad_chunk_size, chunk_argnums=1, nondiff_argnums=1, ) diff --git a/netket/vqs/mc/mc_state/expect_grad_chunked.py b/netket/vqs/mc/mc_state/expect_grad_chunked.py index a643a82ad9..1775eeeb67 100644 --- a/netket/vqs/mc/mc_state/expect_grad_chunked.py +++ b/netket/vqs/mc/mc_state/expect_grad_chunked.py @@ -35,6 +35,7 @@ def expect_and_grad_nochunking( # noqa: F811 operator: AbstractOperator, use_covariance: Union[Literal[True], Literal[False]], chunk_size: None, + grad_chunk_size: None, *args, **kwargs, ): @@ -48,6 +49,7 @@ def expect_and_grad_fallback( # noqa: F811 operator: AbstractOperator, use_covariance: Union[Literal[True], Literal[False]], chunk_size: Any, + grad_chunk_size: Any, *args, **kwargs, ): @@ -67,10 +69,13 @@ def expect_and_grad_covariance_chunked( # noqa: F811 Ô: AbstractOperator, use_covariance: Literal[True], chunk_size: int, + grad_chunk_size: int, *, mutable: CollectionFilter, ) -> tuple[Stats, PyTree]: - Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, mutable=mutable) + Ō, Ō_grad = expect_and_forces( + vstate, Ô, chunk_size, grad_chunk_size, mutable=mutable + ) Ō_grad = _force_to_grad(Ō_grad, vstate.parameters) return Ō, Ō_grad diff --git a/netket/vqs/mc/mc_state/state.py b/netket/vqs/mc/mc_state/state.py index c0deb1ca87..8ce7e65a2f 100644 --- a/netket/vqs/mc/mc_state/state.py +++ b/netket/vqs/mc/mc_state/state.py @@ -132,6 +132,8 @@ class MCState(VariationalState): _chunk_size: Optional[int] = None + _grad_chunk_size: Optional[int] = None + def __init__( self, sampler: Sampler, @@ -141,6 +143,7 @@ def __init__( n_samples_per_rank: Optional[int] = None, n_discard_per_chain: Optional[int] = None, chunk_size: Optional[int] = None, + grad_chunk_size: Optional[int] = None variables: Optional[PyTree] = None, init_fun: NNInitFunc = None, apply_fun: Callable = None, @@ -436,6 +439,19 @@ def chunk_size(self, chunk_size: Optional[int]): self._chunk_size = chunk_size + @grad_chunk_size.setter + def grad_chunk_size(self, grad_chunk_size: Optional[int]): + if grad_chunk_size is None: + self._grad_chunk_size = None + return + + if grad_chunk_size <= 0: + raise ValueError("Grad chunk size must be a positive integer. ") + + if not self.n_samples_per_rank % grad_chunk_size == 0: + raise ValueError("""Grad chunk size must be a divisor of the + number of samples per rank""") + def reset(self): """ Resets the sampled states. This method is called automatically every time @@ -597,7 +613,7 @@ def expect_and_grad( mutable = self.mutable return expect_and_grad( - self, Ô, use_covariance, self.chunk_size, mutable=mutable + self, Ô, use_covariance, self.chunk_size, self.grad_chunk_size, mutable=mutable ) # override to use chunks @@ -636,7 +652,7 @@ def expect_and_forces( if mutable is None: mutable = self.mutable - return expect_and_forces(self, Ô, self.chunk_size, mutable=mutable) + return expect_and_forces(self, Ô, self.chunk_size, self.grad_chunk_size, mutable=mutable) def quantum_geometric_tensor( self, qgt_T: Optional[LinearOperator] = None