Skip to content

Commit

Permalink
Add grad chunk size
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisrothUT committed Sep 25, 2023
1 parent a88e5c8 commit b13b4b1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 4 deletions.
10 changes: 9 additions & 1 deletion netket/vqs/mc/mc_state/expect_forces_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def expect_and_forces_nochunking( # noqa: F811
vstate: MCState,
operator: AbstractOperator,
chunk_size: None,
grad_chunk_size: None,
*args,
**kwargs,
):
Expand All @@ -54,6 +55,7 @@ def expect_and_forces_fallback( # noqa: F811
vstate: MCState,
operator: AbstractOperator,
chunk_size: Any,
grad_chunk_size: Any,
*args,
**kwargs,
):
Expand All @@ -72,6 +74,7 @@ def expect_and_forces_impl( # noqa: F811
vstate: MCState,
: AbstractOperator,
chunk_size: int,
grad_chunk_size: int,
*,
mutable: CollectionFilter,
) -> tuple[Stats, PyTree]:
Expand All @@ -81,6 +84,7 @@ def expect_and_forces_impl( # noqa: F811

, Ō_grad, new_model_state = forces_expect_hermitian_chunked(
chunk_size,
grad_chunk_size,
local_estimator_fun,
vstate._apply_fun,
mutable,
Expand All @@ -99,6 +103,7 @@ def expect_and_forces_impl( # noqa: F811
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
def forces_expect_hermitian_chunked(
chunk_size: int,
grad_chunk_size: int,
local_value_kernel_chunked: Callable,
model_apply_fun: Callable,
mutable: CollectionFilter,
Expand All @@ -111,6 +116,9 @@ def forces_expect_hermitian_chunked(
if jnp.ndim(σ) != 2:
σ = σ.reshape((-1, σ_shape[-1]))

if grad_chunk_size is None:
grad_chunk_size = chunk_size

n_samples = σ.shape[0] * mpi.n_nodes

O_loc = local_value_kernel_chunked(
Expand All @@ -134,7 +142,7 @@ def forces_expect_hermitian_chunked(
parameters,
σ,
conjugate=True,
chunk_size=chunk_size,
chunk_size=grad_chunk_size,
chunk_argnums=1,
nondiff_argnums=1,
)
Expand Down
7 changes: 6 additions & 1 deletion netket/vqs/mc/mc_state/expect_grad_chunked.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def expect_and_grad_nochunking( # noqa: F811
operator: AbstractOperator,
use_covariance: Union[Literal[True], Literal[False]],
chunk_size: None,
grad_chunk_size: None,
*args,
**kwargs,
):
Expand All @@ -48,6 +49,7 @@ def expect_and_grad_fallback( # noqa: F811
operator: AbstractOperator,
use_covariance: Union[Literal[True], Literal[False]],
chunk_size: Any,
grad_chunk_size: Any,
*args,
**kwargs,
):
Expand All @@ -67,10 +69,13 @@ def expect_and_grad_covariance_chunked( # noqa: F811
: AbstractOperator,
use_covariance: Literal[True],
chunk_size: int,
grad_chunk_size: int,
*,
mutable: CollectionFilter,
) -> tuple[Stats, PyTree]:
, Ō_grad = expect_and_forces(vstate, , chunk_size, mutable=mutable)
, Ō_grad = expect_and_forces(
vstate, , chunk_size, grad_chunk_size, mutable=mutable
)
Ō_grad = _force_to_grad(Ō_grad, vstate.parameters)
return , Ō_grad

Expand Down
20 changes: 18 additions & 2 deletions netket/vqs/mc/mc_state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ class MCState(VariationalState):

_chunk_size: Optional[int] = None

_grad_chunk_size: Optional[int] = None

def __init__(
self,
sampler: Sampler,
Expand All @@ -141,6 +143,7 @@ def __init__(
n_samples_per_rank: Optional[int] = None,
n_discard_per_chain: Optional[int] = None,
chunk_size: Optional[int] = None,
grad_chunk_size: Optional[int] = None
variables: Optional[PyTree] = None,

Check failure on line 147 in netket/vqs/mc/mc_state/state.py

View workflow job for this annotation

GitHub Actions / Code (ruff)

Ruff (E999)

netket/vqs/mc/mc_state/state.py:147:9: E999 SyntaxError: Unexpected token 'variables'
init_fun: NNInitFunc = None,
apply_fun: Callable = None,
Expand Down Expand Up @@ -436,6 +439,19 @@ def chunk_size(self, chunk_size: Optional[int]):

self._chunk_size = chunk_size

@grad_chunk_size.setter
def grad_chunk_size(self, grad_chunk_size: Optional[int]):
if grad_chunk_size is None:
self._grad_chunk_size = None
return

if grad_chunk_size <= 0:
raise ValueError("Grad chunk size must be a positive integer. ")

if not self.n_samples_per_rank % grad_chunk_size == 0:
raise ValueError("""Grad chunk size must be a divisor of the

Check failure on line 452 in netket/vqs/mc/mc_state/state.py

View workflow job for this annotation

GitHub Actions / Code (ruff)

Ruff (W291)

netket/vqs/mc/mc_state/state.py:452:73: W291 Trailing whitespace
number of samples per rank""")

def reset(self):
"""
Resets the sampled states. This method is called automatically every time
Expand Down Expand Up @@ -597,7 +613,7 @@ def expect_and_grad(
mutable = self.mutable

return expect_and_grad(
self, , use_covariance, self.chunk_size, mutable=mutable
self, , use_covariance, self.chunk_size, self.grad_chunk_size, mutable=mutable
)

# override to use chunks
Expand Down Expand Up @@ -636,7 +652,7 @@ def expect_and_forces(
if mutable is None:
mutable = self.mutable

return expect_and_forces(self, , self.chunk_size, mutable=mutable)
return expect_and_forces(self, , self.chunk_size, self.grad_chunk_size, mutable=mutable)

def quantum_geometric_tensor(
self, qgt_T: Optional[LinearOperator] = None
Expand Down

0 comments on commit b13b4b1

Please sign in to comment.