diff --git a/openmmtools/data/reporter-examples/alanine_dipeptide_legacy.nc b/openmmtools/data/reporter-examples/alanine_dipeptide_legacy.nc
new file mode 100644
index 00000000..32d415e8
Binary files /dev/null and b/openmmtools/data/reporter-examples/alanine_dipeptide_legacy.nc differ
diff --git a/openmmtools/data/reporter-examples/alanine_dipeptide_legacy_checkpoint.nc b/openmmtools/data/reporter-examples/alanine_dipeptide_legacy_checkpoint.nc
new file mode 100644
index 00000000..95175726
Binary files /dev/null and b/openmmtools/data/reporter-examples/alanine_dipeptide_legacy_checkpoint.nc differ
diff --git a/openmmtools/multistate/multistatereporter.py b/openmmtools/multistate/multistatereporter.py
index baf8e237..9f63456c 100644
--- a/openmmtools/multistate/multistatereporter.py
+++ b/openmmtools/multistate/multistatereporter.py
@@ -1666,6 +1666,10 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i
             # Store velocites
             # TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
             # sampler_state different from origin.
+            if 'velocities' not in storage.variables:
+                # create variable with expected dimensions and shape
+                storage.createVariable('velocities', storage.variables['positions'].dtype,
+                                       dimensions=storage.variables['positions'].dimensions)
             storage.variables['velocities'][write_iteration, :, :, :] = velocities
 
             if is_periodic:
@@ -1723,9 +1727,14 @@ def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoi
             x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
             positions = unit.Quantity(x, unit.nanometers)
 
-            # Restore velocities.
-            x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
-            velocities = unit.Quantity(x, unit.nanometer/unit.picoseconds)
+            # Restore velocities
+            # try-except enables reading legacy/older serialized objects from openmmtools<0.21.3
+            try:
+                x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
+            except KeyError:  # Velocities key/variable not found in serialization (openmmtools<=0.21.2)
+                # use zeros shaped like the positions array still held in x (<0.21.3 behavior)
+                x = np.zeros_like(x)
+            velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds)
 
             if 'box_vectors' in storage.variables:
                 # Restore box vectors.
diff --git a/openmmtools/tests/test_sampling.py b/openmmtools/tests/test_sampling.py
index 553b28b9..3fc353fc 100644
--- a/openmmtools/tests/test_sampling.py
+++ b/openmmtools/tests/test_sampling.py
@@ -16,14 +16,13 @@
 import contextlib
 import copy
 import inspect
-import math
 import os
 import pickle
+import shutil
 import sys
 from io import StringIO
 
 import numpy as np
-import scipy.integrate
 import yaml
 from nose.plugins.attrib import attr
 from nose.tools import assert_raises
@@ -580,9 +579,8 @@ def test_store_mixing_statistics(self):
 # TEST MULTISTATE SAMPLERS
 # ==============================================================================
 
-class TestMultiStateSampler(object):
-    """Base test suite for the multi-state classes"""
-
+class TestBaseMultistateSampler(object):
+    """Minimal Base class to test sampler objects"""
     # ------------------------------------
     # VARIABLES TO SET FOR EACH TEST CLASS
     # ------------------------------------
@@ -592,6 +590,55 @@ class TestMultiStateSampler(object):
     SAMPLER = MultiStateSampler
     REPORTER = MultiStateReporter
 
+    @staticmethod
+    @contextlib.contextmanager
+    def temporary_storage_path():
+        """Generate a storage path in a temporary folder and share it.
+
+        It makes it possible to run tests on multiple nodes with MPI.
+
+        """
+        mpicomm = mpiplus.get_mpicomm()
+        with temporary_directory() as tmp_dir_path:
+            storage_file_path = os.path.join(tmp_dir_path, 'test_storage.nc')
+            if mpicomm is not None:
+                storage_file_path = mpicomm.bcast(storage_file_path, root=0)
+            yield storage_file_path
+
+    @staticmethod
+    @contextlib.contextmanager
+    def captured_output():
+        """Temporarily replace sys.stdout/sys.stderr with StringIO buffers and yield the pair."""
+        new_out, new_err = StringIO(), StringIO()
+        old_out, old_err = sys.stdout, sys.stderr
+        try:
+            sys.stdout, sys.stderr = new_out, new_err
+            yield sys.stdout, sys.stderr
+        finally:
+            sys.stdout, sys.stderr = old_out, old_err
+
+    @staticmethod
+    def property_creator(name, on_disk_name, value, on_disk_value):
+        """
+        Helper to create additional properties to create for checking
+
+        Makes a nested dict where the top key is the 'name', with one
+        value as a dict where the sub-dict is of the form:
+        {'value': value,
+         'on_disk_name': on_disk_name,
+         'on_disk_value': on_disk_value
+        }
+        """
+        return {name: {
+
+            'value': value,
+            'on_disk_value': on_disk_value,
+            'on_disk_name': on_disk_name
+        }}
+
+
+class TestMultiStateSampler(TestBaseMultistateSampler):
+    """Base test suite for the multi-state classes"""
     # --------------------------------------
     # Optional helper function to overwrite.
     # --------------------------------------
@@ -707,21 +754,6 @@ def setup_class(cls):
         print(output_descr)
         print("#" * len_output)
 
-    @staticmethod
-    @contextlib.contextmanager
-    def temporary_storage_path():
-        """Generate a storage path in a temporary folder and share it.
-
-        It makes it possible to run tests on multiple nodes with MPI.
-
-        """
-        mpicomm = mpiplus.get_mpicomm()
-        with temporary_directory() as tmp_dir_path:
-            storage_file_path = os.path.join(tmp_dir_path, 'test_storage.nc')
-            if mpicomm is not None:
-                storage_file_path = mpicomm.bcast(storage_file_path, root=0)
-            yield storage_file_path
-
     @staticmethod
     def get_node_replica_ids(tot_n_replicas):
         """Return the indices of the replicas that this node is responsible for."""
@@ -731,35 +763,6 @@ def get_node_replica_ids(tot_n_replicas):
         else:
             return set(range(mpicomm.rank, tot_n_replicas, mpicomm.size))
 
-    @staticmethod
-    @contextlib.contextmanager
-    def captured_output():
-        new_out, new_err = StringIO(), StringIO()
-        old_out, old_err = sys.stdout, sys.stderr
-        try:
-            sys.stdout, sys.stderr = new_out, new_err
-            yield sys.stdout, sys.stderr
-        finally:
-            sys.stdout, sys.stderr = old_out, old_err
-
-    @staticmethod
-    def property_creator(name, on_disk_name, value, on_disk_value):
-        """
-        Helper to create additional properties to create for checking
-
-        Makes a nested dict where the top key is the 'name', with one
-        value as a dict where the sub-dict is of the form:
-        {'value': value,
-         'on_disk_name': on_disk_name,
-         'on_disk_value': on_disk_value
-        }
-        """
-        return {name: {
-
-            'value': value,
-            'on_disk_value': on_disk_value,
-            'on_disk_name': on_disk_name
-        }}
 
     def test_create(self):
         """Test creation of a new MultiState simulation.
@@ -1917,6 +1920,62 @@ def _compute_energies_independently(cls, sampler):
         return energy_thermodynamic_states, energy_unsampled_states
 
 
+class TestSerializedMultiStateSampler(TestBaseMultistateSampler):
+    """
+    Test suite for serialized MultiStateSampler objects.
+
+    Requires a different class because serialized objects are not fully compatible between different classes.
+    """
+
+    def test_resume_velocities_from_legacy_storage(self):
+        """
+        This tests simulations can be resumed even if velocities are not present in the serialized/reporter file.
+
+        This emulates the behavior of reading older versions (previous to 0.21.3 release) of serialized simulations.
+        """
+        import netCDF4
+        origin_reporter_path = testsystems.get_data_filename(
+            os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy.nc")
+        )
+        origin_checkpoint_path = testsystems.get_data_filename(
+            os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy_checkpoint.nc")
+        )
+        # Assert no velocities in legacy dataset variables (context manager closes the file handle)
+        with netCDF4.Dataset(origin_checkpoint_path) as netcdf_data:  # open checkpoint for reading
+            assert 'velocities' not in netcdf_data.variables, \
+                "velocities variable should not exist in legacy reporter netcdf file."
+
+        with self.temporary_storage_path() as storage_path:
+            # copy files to temporary directory
+            temporary_checkpoint_path = f"{os.path.splitext(storage_path)[0]}_checkpoint.nc"
+            reporter_path = shutil.copy(origin_reporter_path, storage_path)  # copy reporter file
+            checkpoint_path = shutil.copy(origin_checkpoint_path, temporary_checkpoint_path)  # copy checkpoint file
+            # Load repex simulation
+            reporter = self.REPORTER(reporter_path, checkpoint_interval=1)
+            sampler = self.SAMPLER.from_storage(reporter)
+            # Assert velocities are initialized as zeros
+            for state in sampler.sampler_states:
+                assert np.all(state.velocities.value_in_unit_system(unit.md_unit_system) == 0), \
+                    "Velocities in sampler state from legacy checkpoint are expected to be all zeros."
+
+            # Resume simulation
+            sampler.extend(n_iterations=1)
+
+            # delete reporters and load again
+            del sampler
+            reporter.close()
+            # assert velocities variable exist; the dataset must be closed again before the
+            # sampler is reloaded below (leaving it open errors on the next load)
+            with netCDF4.Dataset(checkpoint_path) as netcdf_data:  # open checkpoint for reading
+                assert 'velocities' in netcdf_data.variables, \
+                    "velocities variable should exist in new reporter netcdf file."
+            # Load repex simulation from new reporter file
+            new_sampler = self.SAMPLER.from_storage(reporter)
+            # assert velocities in sampler states are non-zero
+            for state in new_sampler.sampler_states:
+                assert np.any(state.velocities.value_in_unit_system(unit.md_unit_system) != 0), \
+                    "At least some velocity in sampler state from new checkpoint is expected to be different from zero."
+
 # ==============================================================================
 # MAIN AND TESTS
 # ==============================================================================