Skip to content

Commit

Permalink
handle comments and update test for 100% coverage of added code
Browse files Browse the repository at this point in the history
  • Loading branch information
sroet committed Dec 24, 2020
1 parent 52bc8c7 commit c3eb636
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 12 deletions.
75 changes: 63 additions & 12 deletions openpathsampling/pathmovers/spring_shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class SpringShootingSelector(paths.ShootingPointSelector):
spring shooting simulation. It uses a biased potential in the shape of
min(1, e^(-k*i)) for a forward shooting move and min(1, e^(k*i)) for a
backwards shooting move, where i is a frame number in the range
[-delta_max, delta_max] where 0 is the last accepted shooting frame index.
[-delta_max, delta_max] and represents a shift (in frames) relative to
the last accepted shooting frame index.
Parameters
----------
Expand Down Expand Up @@ -40,22 +41,72 @@ def __init__(self, delta_max, k_spring, initial_guess=None):
else:
self.k_spring = k_spring

# Initiate the class variable
# Initiate the instance variable
self.previous_snapshot = initial_guess
self.trial_snapshot = initial_guess
self.previous_trajectory = None

# Make the bias lists
self._fw_prob_list = self._biases(self.delta_max, -self.k_spring)
self._bw_prob_list = self._biases(self.delta_max, self.k_spring)
self._fw_total_bias = self.sum_bias(self.delta_max, -self.k_spring)
self._bw_total_bias = self.sum_bias(self.delta_max, self.k_spring)
self._fw_prob_list = self._spring_biases(self.delta_max,
-self.k_spring)
self._bw_prob_list = self._spring_biases(self.delta_max,
self.k_spring)
self._fw_total_bias = self.sum_spring_bias(self.delta_max,
-self.k_spring)
self._bw_total_bias = self.sum_spring_bias(self.delta_max,
self.k_spring)

# Check if the bias potentials are equal

self._total_bias = self._fw_total_bias
self.check_sanity()

def f(self, snapshot, trajectory, direction=None):

if direction is None:
raise NotImplementedError("f is not defined without a direction.")

if str(direction).lower() not in {'forward', 'backward'}:
raise NotImplementedError("direction must be either 'forward' or "
"'backward'.")
elif direction == 'forward':
prob_list = self._fw_prob_list
else:
# Should be backward
prob_list = self._bw_prob_list
if trajectory is not self.previous_trajectory:
raise NotImplementedError("f is not defined for any other "
"trajectory than "
"self.previous_trajectory.")
if self.previous_snapshot is None:
raise NotImplementedError("f is only defined if a previous index "
"is known.")
previous_shooting_index = self.previous_snapshot
if previous_shooting_index < 0:
previous_shooting_index += len(trajectory)

idx = trajectory.index(snapshot)
diff = (idx-previous_shooting_index)
prob_idx = diff+self.delta_max
if prob_idx < 0 or prob_idx >= len(prob_list):
return 0
else:
return prob_list[prob_idx]

def probability(self, snapshot, trajectory, direction=None):
if direction is None:
raise NotImplementedError("probability is not defined without a "
"direction.")
if str(direction).lower() not in {'forward', 'backward'}:
raise NotImplementedError("direction must be either 'forward' or "
"'backward'.")
elif direction == 'forward':
prob_total = self._fw_total_bias
else:
# Should be backwards
prob_total = self._bw_total_bias
return self.f(snapshot, trajectory, direction=direction)/prob_total

def check_sanity(self):
"""
Checks the sanity of the selector, making sure that de biases are
Expand All @@ -69,10 +120,10 @@ def check_sanity(self):
raise RuntimeError("Sum of the biases changed")

@staticmethod
def _biases(delta_max, k_spring):
def _spring_biases(delta_max, k_spring):
"""
Calculates the list of biases depending on delta_max and k_spring
using the formula min(1, e^(k*i)) where i is in range
Calculates the list of spring biases depending on delta_max and
k_spring using the formula min(1, e^(k*i)) where i is in range
[-delta_max,delta_max]
Parameters
Expand All @@ -90,7 +141,7 @@ def _biases(delta_max, k_spring):
return([min([1, np.exp(k_spring*float(i-delta_max))])
for i in range(2*delta_max+1)])

def sum_bias(self, delta_max, k_spring):
def sum_spring_bias(self, delta_max, k_spring):
"""
Calculates the sum of the biases, given a delta_max and a k_spring.
Expand All @@ -108,9 +159,9 @@ def sum_bias(self, delta_max, k_spring):
"""
# Sum small to big to prevent float summing errors
if k_spring < 0:
return sum(self._biases(delta_max, k_spring)[::-1])
return sum(self._spring_biases(delta_max, k_spring)[::-1])
else:
return sum(self._biases(delta_max, k_spring))
return sum(self._spring_biases(delta_max, k_spring))

def probability_ratio(self, snapshot, initial_trajectory,
trial_trajectory):
Expand Down
84 changes: 84 additions & 0 deletions openpathsampling/tests/test_spring_shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,79 @@ def default_selector(self):
sel._total_bias = 1.0
return sel

def test_f_no_direction(self):
sel = self.default_selector
with pytest.raises(NotImplementedError,
match="direction"):
sel.f(self.mytraj[0], self.mytraj)

def test_f_broken_direction(self):
sel = self.default_selector
with pytest.raises(NotImplementedError,
match="direction"):
sel.f(self.mytraj[0], self.mytraj, direction='broken')

def test_f_no_previous_index(self):
sel = self.default_selector
sel.previous_trajectory = self.mytraj
# Break this to have no initial guess
sel.previous_snapshot = None
with pytest.raises(NotImplementedError,
match="index"):
sel.f(self.mytraj[0], self.mytraj, direction='forward')

def test_f_no_prev_traj(self):
sel = self.default_selector
with pytest.raises(NotImplementedError,
match="self.previous_trajectory"):
sel.f(self.mytraj[0], self.mytraj, direction="forward")

def test_f_wrong_traj(self):
sel = self.default_selector
sel.previous_trajectory = self.mytraj
with pytest.raises(NotImplementedError,
match="self.previous_trajectory"):
# Slice trajectory down to trigger exception
sel.f(self.mytraj[0], self.mytraj[:4], direction="forward")

def test_probability_no_direction(self):
sel = self.default_selector
with pytest.raises(NotImplementedError,
match="direction"):
sel.probability(self.mytraj[0], self.mytraj)

def test_probability_broken_direction(self):
sel = self.default_selector
with pytest.raises(NotImplementedError,
match="direction"):
sel.probability(self.mytraj[0], self.mytraj, direction='broken')

def test_probability_forward(self):
sel = self.default_selector
sel.previous_trajectory = self.mytraj
sel._fw_prob_list = [1.0, 1.0, 0.0]
sel._bw_prob_list = [0.0, 1.0, 1.0]
sel._fw_total_bias = 2.0
sel._total_bias = 2.0
correct = [0.0, 0.0, 0.5, 0.5, 0.0]
for frame, c in zip(self.mytraj, correct):
prob = sel.probability(frame, self.mytraj, 'forward')
assert prob == c

def test_probability_backward(self):
sel = self.default_selector
sel.previous_trajectory = self.mytraj
sel._fw_prob_list = [1.0, 1.0, 0.0]
sel._bw_prob_list = [0.0, 1.0, 1.0]
sel._total_bias = 2.0
sel._bw_total_bias = 2.0
# Set index from 3 to min2 as would normally happen
sel.previous_snapshot = -2
correct = [0.0, 0.0, 0.0, 0.5, 0.5]
for frame, c in zip(self.mytraj, correct):
prob = sel.probability(frame, self.mytraj, 'backward')
assert prob == c

def test_forward_pick(self):
sel = self.default_selector
pick = sel.pick(trajectory=self.mytraj, direction='forward')
Expand Down Expand Up @@ -299,6 +372,17 @@ def test_from_dict(self):
assert len(mover.movers) == 2
assert isinstance(mover.movers[0], SpringMover)
assert isinstance(mover.movers[1], SpringMover)

def test_dict_cycle(self):
old_mover = SpringShootingMover(ensemble=self.ens, delta_max=1,
k_spring=0.0, engine=self.dyn,
initial_guess=3)
dct = old_mover.to_dict()
mover = SpringShootingMover.from_dict(dct)
assert isinstance(mover, SpringShootingMover)
assert len(mover.movers) == 2
assert isinstance(mover.movers[0], SpringMover)
assert isinstance(mover.movers[1], SpringMover)
assert mover.movers[0].selector is mover.movers[1].selector


Expand Down

0 comments on commit c3eb636

Please sign in to comment.