Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restore velocities only if found in serialized nc file #613

Merged
merged 7 commits into from Aug 2, 2022
Binary file not shown.
Binary file not shown.
14 changes: 11 additions & 3 deletions openmmtools/multistate/multistatereporter.py
Expand Up @@ -1666,6 +1666,10 @@ def _write_sampler_states_to_given_file(self, sampler_states: list, iteration: i
# Store velocites
# TODO: This stores velocities as zeros if no velocities are present in the sampler state. Making restored
# sampler_state different from origin.
if 'velocities' not in storage.variables:
# create variable with expected dimensions and shape
storage.createVariable('velocities', storage.variables['positions'].dtype,
dimensions=storage.variables['positions'].dimensions)
storage.variables['velocities'][write_iteration, :, :, :] = velocities

if is_periodic:
Expand Down Expand Up @@ -1723,9 +1727,13 @@ def _read_sampler_states_from_given_file(self, iteration, storage_file='checkpoi
x = storage.variables['positions'][read_iteration, replica_index, :, :].astype(np.float64)
positions = unit.Quantity(x, unit.nanometers)

# Restore velocities.
x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
velocities = unit.Quantity(x, unit.nanometer/unit.picoseconds)
# Restore velocities for recent versions of openmmtools (>0.21.2)
try:
x = storage.variables['velocities'][read_iteration, replica_index, :, :].astype(np.float64)
velocities = unit.Quantity(x, unit.nanometer / unit.picoseconds)
except KeyError: # Velocities key/variable not found in serialization (openmmtools<=0.21.2)
# pass zeros as velocities when key is not found (<0.21.3 behavior)
velocities = np.zeros_like(positions)

if 'box_vectors' in storage.variables:
# Restore box vectors.
Expand Down
51 changes: 49 additions & 2 deletions openmmtools/tests/test_sampling.py
Expand Up @@ -16,14 +16,13 @@
import contextlib
import copy
import inspect
import math
import os
import pickle
import shutil
import sys
from io import StringIO

import numpy as np
import scipy.integrate
import yaml
from nose.plugins.attrib import attr
from nose.tools import assert_raises
Expand Down Expand Up @@ -1307,6 +1306,54 @@ def test_resume_positions_velocities_from_storage(self):
assert np.allclose(original_state.positions, restored_state.positions)
assert np.allclose(original_state.velocities, restored_state.velocities)

def test_resume_velocities_from_legacy_storage(self):
"""
This tests simulations can be resumed even if velocities are not present in the serialized/reporter file.

This emulates the behavior of reading older versions (previous to 0.21.3 release) of serialized simulations.
"""
import netCDF4
origin_reporter_path = testsystems.get_data_filename(
os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy.nc")
)
origin_checkpoint_path = testsystems.get_data_filename(
os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy_checkpoint.nc")
)
# Assert no velocities in legacy dataset variables
netcdf_data = netCDF4.Dataset(origin_checkpoint_path) # open checkpoint for reading
assert 'velocities' not in netcdf_data.variables, "velocities variable should not exist in legacy reporter " \
"netcdf file."

with self.temporary_storage_path() as storage_path:
# copy files to temporary directory
temporary_checkpoint_path = f"{os.path.splitext(storage_path)[0]}_checkpoint.nc"
reporter_path = shutil.copy(origin_reporter_path, storage_path) # copy reporter file
checkpoint_path = shutil.copy(origin_checkpoint_path, temporary_checkpoint_path) # copy checkpoint file
# Load repex simulation
reporter = self.REPORTER(reporter_path, checkpoint_interval=1)
sampler = self.SAMPLER.from_storage(reporter)
# Assert velocities are initialized as zeros
for state in sampler.sampler_states:
assert np.all(state.velocities.value_in_unit_system(unit.md_unit_system) == 0), \
"Velocities in sampler state from legacy checkpoint are expected to be all zeros."

# Resume simulation
sampler.extend(n_iterations=1)

# delete reporters and load again
del sampler
reporter.close()
# assert velocities variable exist
netcdf_data = netCDF4.Dataset(checkpoint_path) # open checkpoint for reading
assert 'velocities' in netcdf_data.variables, "velocities variable should exist in new reporter " \
"netcdf file."
netcdf_data.close() # close or it errors in next line
# Load repex simulation from new reporter file
new_sampler = self.SAMPLER.from_storage(reporter)
# assert velocities in sampler states are non-zero
for state in new_sampler.sampler_states:
assert np.any(state.velocities.value_in_unit_system(unit.md_unit_system) != 0), \
"At least some velocity in sampler state from new checkpoint is expected to different from zero."

def test_last_iteration_functions(self):
"""Test that the last_iteration functions work right"""
Expand Down