Skip to content

Commit

Permalink
Merge 8487db8 into eb62600
Browse files Browse the repository at this point in the history
  • Loading branch information
dwhswenson committed May 18, 2016
2 parents eb62600 + 8487db8 commit 1262484
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 22 deletions.
76 changes: 57 additions & 19 deletions openpathsampling/analysis/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ def sampling_ensembles(self):
"""
return sum([t.ensembles for t in self.sampling_transitions], [])

@property
def analysis_ensembles(self):
"""
Ensembles from the analysis transitions, excluding special ensembles.
"""
return sum([t.ensembles for t in self.transitions.values()], [])

@property
def all_ensembles(self):
"""
Expand Down Expand Up @@ -75,6 +82,8 @@ class into a simple subclass of this GeneralizedTPSNetwork, which acts
acceptable initial states
final_states : list of :class:`.Volume`
acceptable final states
allow_self_transitions : bool
whether self-transitions (A->A) are allowed; default is False
Attributes
----------
Expand All @@ -83,7 +92,8 @@ class into a simple subclass of this GeneralizedTPSNetwork, which acts
pathlengths.
"""
TransitionType = NotImplemented
def __init__(self, initial_states, final_states, **kwargs):
def __init__(self, initial_states, final_states,
allow_self_transitions=False, **kwargs):
# **kwargs gets passed to the transition
super(GeneralizedTPSNetwork, self).__init__()
try:
Expand All @@ -101,23 +111,40 @@ def __init__(self, initial_states, final_states, **kwargs):
self.final_states = final_states

all_initial = paths.join_volumes(initial_states)
all_initial.name = "|".join([v.name for v in initial_states])
if len(initial_states) > 1:
all_initial.name = "|".join([v.name for v in initial_states])

if set(initial_states) == set(final_states):
if set(initial_states) == set(final_states) or len(final_states) == 1:
all_final = all_initial
else:
all_final = paths.join_volumes(final_states)
all_final.name = "|".join([v.name for v in final_states])
self._sampling_transitions = [
self.TransitionType(all_initial, all_final, **kwargs)
]

self._sampling_transitions = []
for my_initial in initial_states:
my_final_states = [final for final in final_states
if my_initial != final or allow_self_transitions]
my_final = paths.join_volumes(my_final_states)
if len(my_final_states) > 1:
my_final.name = "|".join([v.name for v in my_final_states])
if len(self._sampling_transitions) == 0:
self._sampling_transitions = [
self.TransitionType(my_initial, my_final, **kwargs)
]
elif len(self._sampling_transitions) == 1:
self._sampling_transitions[0].add_transition(my_initial,
my_final)
else:
raise RuntimeError("More than one sampling transition for TPS?")

self.transitions = {
(initial, final) : self.TransitionType(initial, final, **kwargs)
for (initial, final) in itertools.product(initial_states,
final_states)
if initial != final
}


def to_dict(self):
ret_dict = {
'transitions' : self.transitions,
Expand Down Expand Up @@ -169,8 +196,10 @@ def from_state_pairs(cls, state_pairs, **kwargs):


@classmethod
def from_states_all_to_all(cls, states, **kwargs):
return cls(states, states, **kwargs)
def from_states_all_to_all(cls, states, allow_self_transitions=False,
**kwargs):
return cls(states, states,
allow_self_transitions=allow_self_transitions, **kwargs)


class TPSNetwork(GeneralizedTPSNetwork):
Expand All @@ -180,16 +209,20 @@ class TPSNetwork(GeneralizedTPSNetwork):
TransitionType = paths.TPSTransition
# we implement these functions entirely to fix the signature (super's
# version allow arbitrary kwargs) so the documentation can read them.
def __init__(self, initial_states, final_states):
super(TPSNetwork, self).__init__(initial_states, final_states)
def __init__(self, initial_states, final_states,
allow_self_transitions=False):
super(TPSNetwork, self).__init__(initial_states, final_states,
allow_self_transitions)

@classmethod
def from_state_pairs(cls, state_pairs):
def from_state_pairs(cls, state_pairs, allow_self_transitions=False):
return super(TPSNetwork, cls).from_state_pairs(state_pairs)

@classmethod
def from_states_all_to_all(cls, states):
return super(TPSNetwork, cls).from_states_all_to_all(states)
def from_states_all_to_all(cls, states, allow_self_transitions=False):
return super(TPSNetwork, cls).from_states_all_to_all(
states, allow_self_transitions
)


class FixedLengthTPSNetwork(GeneralizedTPSNetwork):
Expand All @@ -201,10 +234,12 @@ class FixedLengthTPSNetwork(GeneralizedTPSNetwork):
# However, without them, we need to explicitly name `length` as
# length=value in these functions. This frees us of that, and gives us
# clearer documentation.
def __init__(self, initial_states, final_states, length):
super(FixedLengthTPSNetwork, self).__init__(initial_states,
final_states,
length=length)
def __init__(self, initial_states, final_states, length,
allow_self_transitions=False):
super(FixedLengthTPSNetwork, self).__init__(
initial_states, final_states,
allow_self_transitions=allow_self_transitions, length=length
)

@classmethod
def from_state_pairs(cls, state_pairs, length):
Expand All @@ -213,9 +248,12 @@ def from_state_pairs(cls, state_pairs, length):
)

@classmethod
def from_states_all_to_all(cls, states, length):
def from_states_all_to_all(cls, states, length,
allow_self_transitions=False):
return super(FixedLengthTPSNetwork, cls).from_states_all_to_all(
states, length=length
states=states,
allow_self_transitions=allow_self_transitions,
length=length
)


Expand Down
4 changes: 2 additions & 2 deletions openpathsampling/analysis/tis_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _tps_ensemble(self, stateA, stateB):
paths.AllInXEnsemble(stateB) & paths.LengthEnsemble(1)
])

def add_transition(self, stateA, stateB):
new_ens = self._tps_ensemble(stateA, stateB)
def add_transition(self, stateA, stateB, **kwargs):
new_ens = self._tps_ensemble(stateA, stateB, **kwargs)
try:
self.ensembles[0] = self.ensembles[0] | new_ens
except AttributeError:
Expand Down
67 changes: 66 additions & 1 deletion openpathsampling/tests/testnetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,19 @@ def test_autonaming(self):
class testTPSNetwork(object):
def setup(self):
from test_helpers import CallIdentity
xval = CallIdentity()
xval = paths.CV_Function("xval", lambda snap: snap.xyz[0][0])
self.stateA = paths.CVRangeVolume(xval, float("-inf"), -0.5)
self.stateB = paths.CVRangeVolume(xval, -0.1, 0.1)
self.stateC = paths.CVRangeVolume(xval, 0.5, float("inf"))
self.states = [self.stateA, self.stateB, self.stateC]
self.traj = {}
self.traj['AA'] = make_1d_traj([-0.51, -0.49, -0.49, -0.52])
self.traj['AB'] = make_1d_traj([-0.51, -0.25, -0.25, 0.0])
self.traj['BA'] = make_1d_traj([0.0, -0.15, -0.35, -0.52])
self.traj['BB'] = make_1d_traj([0.0, -0.25, 0.25, 0.02])
self.traj['BC'] = make_1d_traj([0.01, 0.16, 0.25, 0.53])
self.traj['CC'] = make_1d_traj([0.51, 0.35, 0.36, 0.55])
self.traj['CA'] = make_1d_traj([0.52, 0.22, -0.22, -0.52])

# define all the test networks as properties: we can do something
# similar then for the fixed path length, and just need to override
Expand Down Expand Up @@ -235,6 +243,34 @@ def test_storage(self):
if os.path.isfile(fname):
os.remove(fname)

def test_allow_self_transitions_false(self):
network = TPSNetwork.from_states_all_to_all(
self.states, allow_self_transitions=False
)
assert_equal(len(network.sampling_ensembles), 1)
ensemble = network.sampling_ensembles[0]
assert_equal(ensemble(self.traj['AA']), False)
assert_equal(ensemble(self.traj['AB']), True)
assert_equal(ensemble(self.traj['BA']), True)
assert_equal(ensemble(self.traj['BC']), True)
assert_equal(ensemble(self.traj['CA']), True)
assert_equal(ensemble(self.traj['BB']), False)
assert_equal(ensemble(self.traj['CC']), False)

def test_allow_self_transitions_true(self):
network = TPSNetwork.from_states_all_to_all(
self.states, allow_self_transitions=True
)
assert_equal(len(network.sampling_ensembles), 1)
ensemble = network.sampling_ensembles[0]
assert_equal(ensemble(self.traj['AA']), True)
assert_equal(ensemble(self.traj['AB']), True)
assert_equal(ensemble(self.traj['BA']), True)
assert_equal(ensemble(self.traj['BC']), True)
assert_equal(ensemble(self.traj['CA']), True)
assert_equal(ensemble(self.traj['BB']), True)
assert_equal(ensemble(self.traj['CC']), True)

class testFixedLengthTPSNetwork(testTPSNetwork):
@property
def network2a(self):
Expand Down Expand Up @@ -280,3 +316,32 @@ def test_lengths(self):
self.network3a, self.network3b, self.network3c]:
assert_equal(network.sampling_transitions[0].length, 10)
assert_equal(network.transitions.values()[0].length, 10)

def test_allow_self_transitions_false(self):
network = FixedLengthTPSNetwork.from_states_all_to_all(
self.states, allow_self_transitions=False, length=4
)
assert_equal(len(network.sampling_ensembles), 1)
ensemble = network.sampling_ensembles[0]
assert_equal(ensemble(self.traj['AA']), False)
assert_equal(ensemble(self.traj['AB']), True)
assert_equal(ensemble(self.traj['BA']), True)
assert_equal(ensemble(self.traj['BC']), True)
assert_equal(ensemble(self.traj['CA']), True)
assert_equal(ensemble(self.traj['BB']), False)
assert_equal(ensemble(self.traj['CC']), False)

def test_allow_self_transitions_true(self):
network = FixedLengthTPSNetwork.from_states_all_to_all(
self.states, allow_self_transitions=True, length=4
)
assert_equal(len(network.sampling_ensembles), 1)
ensemble = network.sampling_ensembles[0]
assert_equal(ensemble(self.traj['AA']), True)
assert_equal(ensemble(self.traj['AB']), True)
assert_equal(ensemble(self.traj['BA']), True)
assert_equal(ensemble(self.traj['BC']), True)
assert_equal(ensemble(self.traj['CA']), True)
assert_equal(ensemble(self.traj['BB']), True)
assert_equal(ensemble(self.traj['CC']), True)

0 comments on commit 1262484

Please sign in to comment.