Skip to content

Commit

Permalink
Allow overriding the function a MCState is sampled from (#1770)
Browse files Browse the repository at this point in the history
Useful for e.g. implementing a importance sampled state as a subclass.
  • Loading branch information
inailuig committed Apr 23, 2024
1 parent 984cecf commit 49ca287
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions netket/vqs/mc/mc_state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,22 @@ def init(self, seed=None, dtype=None):

@property
def model(self) -> Optional[Any]:
"""Returns the model definition of this variational state.
"""Returns the model definition of this variational state."""
return self._model

This field is optional, and is set to `None` if the variational state has
been initialized using a custom function.
@property
def _sampler_model(self):
"""Returns the model definition used for sampling this variational state.
Equal to `.model`.
"""
return self._model
return self.model

@property
def _sampler_variables(self):
"""Returns the variables used for sampling this variational state.
Equal to `.variables`
"""
return self.variables

@property
def sampler(self) -> Sampler:
Expand All @@ -303,7 +313,7 @@ def sampler(self, sampler: Sampler):

self._sampler = sampler
self.sampler_state = self.sampler.init_state(
self.model, self.variables, seed=self._sampler_seed
self._sampler_model, self._sampler_variables, seed=self._sampler_seed
)
self._sampler_state_previous = self.sampler_state

Expand Down Expand Up @@ -479,7 +489,7 @@ def sample(
self._sampler_state_previous = self.sampler_state

self.sampler_state = self.sampler.reset(
self.model, self.variables, self.sampler_state
self._sampler_model, self._sampler_variables, self.sampler_state
)

if self.n_discard_per_chain > 0:
Expand All @@ -492,8 +502,8 @@ def sample(
)

self._samples, self.sampler_state = self.sampler.sample(
self.model,
self.variables,
self._sampler_model,
self._sampler_variables,
state=self.sampler_state,
chain_length=chain_length,
)
Expand Down

0 comments on commit 49ca287

Please sign in to comment.