Skip to content

Commit

Permalink
add no_copy option to NoModifier and Deprecate 'subset_mask'
Browse files Browse the repository at this point in the history
  • Loading branch information
sroet committed Oct 6, 2021
1 parent 8a6057e commit 79c961e
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 5 deletions.
8 changes: 8 additions & 0 deletions openpathsampling/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def version_tuple_to_string(version_tuple):
deprecated_in=(1, 6, 0)
)

NOMODIFICATION_SUBSET_MASK = Deprecation(
problem=("subset_mask is nonsense for NoModification, is ignored in the "
"call, and will be removed as initialisation argument."),
remedy=("You should not use this keyword"),
remove_version=(2, 0),
deprecated_in=(1, 6, 0)
)


# has_deprecations and deprecate hacks to change docstrings inspired by:
# https://stackoverflow.com/a/47441572/4205735
Expand Down
22 changes: 19 additions & 3 deletions openpathsampling/snapshot_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np

from openpathsampling.netcdfplus import StorableNamedObject
from openpathsampling.deprecations import SNAPSHOTMODIFIER_PROB_RAT
from openpathsampling.deprecations import (SNAPSHOTMODIFIER_PROB_RAT,
NOMODIFICATION_SUBSET_MASK)
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -113,9 +114,24 @@ def probability_ratio(self, old_snapshot, new_snapshot):


class NoModification(SnapshotModifier):
"""Modifier with no change: returns a copy of the snapshot."""
"""Modifier with no change: "
Parameters
----------
as_copy : bool, default True
if True calls return a copy of the snapshot, else the snapshot itself.
"""
def __init__(self, subset_mask=None, as_copy=True):
# TODO OPS 2.0: subset mask should be removed from this init call
if subset_mask is not None:
NOMODIFICATION_SUBSET_MASK.warn(stacklevel=3)
# masking is nonsense for no modification, but used in testing so this
# is here to conserve API
super(NoModification, self).__init__(subset_mask=subset_mask)
self.as_copy = as_copy

def __call__(self, snapshot):
return snapshot.copy()
return snapshot.copy() if self.as_copy else snapshot

def probability_ratio(self, old_snapshot, new_snapshot):
"""This modifier does not alter the snapshot, so equal probability."""
Expand Down
29 changes: 27 additions & 2 deletions openpathsampling/tests/test_snapshot_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,12 @@ def test_extract_subset(self):
[2.5, 2.6, 2.7]]))

def test_apply_to_subset(self):
# TODO OPS 2.0: this test should be testing SnapshotModifier instead of
# NoModification, but python 2.7 does not allow for the initialisation
# without overwriting the abstract __call__ for now this raises a
# DeprecationWarning
mod = NoModification(subset_mask=[1, 2])

copy_1Dx = self.snapshot_1D.coordinates.copy()
new_1Dx = mod.apply_to_subset(copy_1Dx, np.array([-1.0, -2.0]))
assert_array_almost_equal(new_1Dx, np.array([0.0, -1.0, -2.0, 3.0]))
Expand Down Expand Up @@ -108,6 +113,16 @@ def test_call(self):
assert self.snapshot_1D.velocities is not new_1D.velocities
assert self.snapshot_3D.coordinates is not new_3D.coordinates
assert self.snapshot_3D.velocities is not new_3D.velocities
# TODO OPS 2.0: the following tests should probabily work
# assert new_1D == self.snapshot_1D
# assert new_3D == self.snapshot_3D

def test_call_no_copy(self):
mod = NoModification(as_copy=False)
new_1D = mod(self.snapshot_1D)
assert new_1D is self.snapshot_1D
new_3D = mod(self.snapshot_3D)
assert new_3D is self.snapshot_3D

def test_probability_ratio(self):
# This should always return 1.0 even for invalid input
Expand Down Expand Up @@ -665,10 +680,10 @@ def test_call_with_linear_momentum_fix(self):
np.array([0.0]*3) * u_vel * u_mass)


class TestSnapshotModifierDeprecation(object):
class TestSnapshotModifierDeprecations(object):
# TODO OPS 2.0: Depr should be completed and this test altered to check for
# the error
def test_raise_deprecation(self):
def test_raise_deprecation_prob_ratio(self):
class DummyMod(SnapshotModifier):
# TODO PY 2.7, don't override __call__ for PY 3.x
def __call__(self, a):
Expand All @@ -679,3 +694,13 @@ def __call__(self, a):
assert len(warn) == 1
assert "NotImplementedError" in str(warn[0])
assert a == 1.0

def test_raise_depr_nomodifier_subset(self):
# The warning might be emited before on line 75
# (NoModification(subset_mask))
# Therefor this will not always trigger
pass
# with pytest.warns(DeprecationWarning) as warn:
# _ = NoModification(subset_mask="foo")
# assert len(warn) == 1
# assert "subset_mask" in str(warn[0])

0 comments on commit 79c961e

Please sign in to comment.