diff --git a/openmmtools/multistate/multistatesampler.py b/openmmtools/multistate/multistatesampler.py index 37d8588c..55b8e3c0 100644 --- a/openmmtools/multistate/multistatesampler.py +++ b/openmmtools/multistate/multistatesampler.py @@ -660,13 +660,47 @@ def equilibrate(self, n_iterations, mcmc_moves=None): raise RuntimeError('The number of MCMCMoves ({}) and ThermodynamicStates ({}) for equilibration' ' must be the same.'.format(len(self._mcmc_moves), self.n_states)) + timer = utils.Timer() + timer.start('Run Equilibration') + # Temporarily set the equilibration MCMCMoves. production_mcmc_moves = self._mcmc_moves self._mcmc_moves = mcmc_moves - for iteration in range(n_iterations): + for iteration in range(1, 1 + n_iterations): logger.debug("Equilibration iteration {}/{}".format(iteration, n_iterations)) + timer.start('Equilibration Iteration') + + # NOTE: Unlike run(), do NOT increment iteration counter. + # self._iteration += 1 + + # Propagate replicas. self._propagate_replicas() + # Compute energies of all replicas at all states + self._compute_energies() + + # Update thermodynamic states + self._mix_replicas() + + # Computing timing information + iteration_time = timer.stop('Equilibration Iteration') + partial_total_time = timer.partial('Run Equilibration') + time_per_iteration = partial_total_time / iteration + estimated_time_remaining = time_per_iteration * (n_iterations - iteration) + estimated_total_time = time_per_iteration * n_iterations + estimated_finish_time = time.time() + estimated_time_remaining + # TODO: Transmit timing information + + # Show timing statistics if debug level is activated. + if logger.isEnabledFor(logging.DEBUG): + logger.debug("Iteration took {:.3f}s.".format(iteration_time)) + if estimated_time_remaining != float('inf'): + logger.debug("Estimated completion (of equilibration only) in {}, at {} (consuming total wall clock time {}).".format( + str(datetime.timedelta(seconds=estimated_time_remaining)), + time.ctime(estimated_finish_time), + str(datetime.timedelta(seconds=estimated_total_time)))) + timer.report_timing() + # Restore production MCMCMoves. self._mcmc_moves = production_mcmc_moves diff --git a/openmmtools/multistate/sams.py b/openmmtools/multistate/sams.py index b1417412..b39fad1f 100644 --- a/openmmtools/multistate/sams.py +++ b/openmmtools/multistate/sams.py @@ -426,11 +426,13 @@ def _mix_replicas(self): logger.debug("Accepted {}/{} attempted swaps ({:.1f}%)".format(n_swaps_accepted, n_swaps_proposed, swap_fraction_accepted * 100.0)) - # Update logZ estimates - self._update_logZ_estimates(replicas_log_P_k) + # Do not update and/or write to disk during equilibration + if self._iteration > 0: + # Update logZ estimates + self._update_logZ_estimates(replicas_log_P_k) - # Update log weights based on target probabilities - self._update_log_weights() + # Update log weights based on target probabilities + self._update_log_weights() def _local_jump(self, replicas_log_P_k): n_replica, n_states, locality = self.n_replicas, self.n_states, self.locality diff --git a/openmmtools/tests/test_sampling.py b/openmmtools/tests/test_sampling.py index 821b739b..631da5c0 100644 --- a/openmmtools/tests/test_sampling.py +++ b/openmmtools/tests/test_sampling.py @@ -1178,8 +1178,8 @@ def test_equilibrate(self): if len(node_replica_ids) == n_replicas: reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1) stored_sampler_states = reporter.read_sampler_states(iteration=0) - for new_state, stored_state in zip(sampler._sampler_states, stored_sampler_states): - assert np.allclose(new_state.positions, stored_state.positions) + for stored_state in stored_sampler_states: + assert any([np.allclose(new_state.positions, stored_state.positions) for new_state in sampler._sampler_states]) # We are still at iteration 0. assert sampler._iteration == 0