Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replica exchange attempts during equilibration phase #556

Merged
merged 5 commits into from Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
36 changes: 35 additions & 1 deletion openmmtools/multistate/multistatesampler.py
Expand Up @@ -660,13 +660,47 @@ def equilibrate(self, n_iterations, mcmc_moves=None):
raise RuntimeError('The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration'
' must be the same.'.format(len(self._mcmc_moves), self.n_states))

timer = utils.Timer()
timer.start('Run Equilibration')

# Temporarily set the equilibration MCMCMoves.
production_mcmc_moves = self._mcmc_moves
self._mcmc_moves = mcmc_moves
for iteration in range(n_iterations):
for iteration in range(1, 1 + n_iterations):
logger.debug("Equilibration iteration {}/{}".format(iteration, n_iterations))
timer.start('Equilibration Iteration')

# NOTE: Unlike run(), do NOT increment iteration counter.
# self._iteration += 1

# Propagate replicas.
self._propagate_replicas()

# Compute energies of all replicas at all states
self._compute_energies()

# Update thermodynamic states
self._mix_replicas()

# Computing timing information
iteration_time = timer.stop('Equilibration Iteration')
partial_total_time = timer.partial('Run Equilibration')
time_per_iteration = partial_total_time / iteration
estimated_time_remaining = time_per_iteration * (n_iterations - iteration)
estimated_total_time = time_per_iteration * n_iterations
estimated_finish_time = time.time() + estimated_time_remaining
# TODO: Transmit timing information

# Show timing statistics if debug level is activated.
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Iteration took {:.3f}s.".format(iteration_time))
if estimated_time_remaining != float('inf'):
logger.debug("Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format(
str(datetime.timedelta(seconds=estimated_time_remaining)),
time.ctime(estimated_finish_time),
str(datetime.timedelta(seconds=estimated_total_time))))
timer.report_timing()

# Restore production MCMCMoves.
self._mcmc_moves = production_mcmc_moves

Expand Down
10 changes: 6 additions & 4 deletions openmmtools/multistate/sams.py
Expand Up @@ -426,11 +426,13 @@ def _mix_replicas(self):
logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed,
swap_fraction_accepted * 100.0))

# Update logZ estimates
self._update_logZ_estimates(replicas_log_P_k)
# Do not update and/or write to disk during equilibration
if self._iteration > 0:
# Update logZ estimates
self._update_logZ_estimates(replicas_log_P_k)

# Update log weights based on target probabilities
self._update_log_weights()
# Update log weights based on target probabilities
self._update_log_weights()

def _local_jump(self, replicas_log_P_k):
n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality
Expand Down
4 changes: 2 additions & 2 deletions openmmtools/tests/test_sampling.py
Expand Up @@ -1178,8 +1178,8 @@ def test_equilibrate(self):
if len(node_replica_ids) == n_replicas:
reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1)
stored_sampler_states = reporter.read_sampler_states(iteration=0)
for new_state, stored_state in zip(sampler._sampler_states, stored_sampler_states):
assert np.allclose(new_state.positions, stored_state.positions)
for stored_state in stored_sampler_states:
assert any([np.allclose(new_state.positions, stored_state.positions) for new_state in sampler._sampler_states])

# We are still at iteration 0.
assert sampler._iteration == 0
Expand Down