Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion paths_cli/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,29 @@
store='schemes',
)

INIT_CONDS = OPSStorageLoadMultiple(
class InitCondsLoader(OPSStorageLoadMultiple):
def _extract_trajectories(self, obj):
import openpathsampling as paths
if isinstance(obj, paths.SampleSet):
yield from (s.trajectory for s in obj)
elif isinstance(obj, paths.Sample):
yield obj.trajectory
elif isinstance(obj, paths.Trajectory):
yield obj
elif isinstance(obj, list):
for o in obj:
yield from self._extract_trajectories(o)
else:
raise RuntimeError("Unknown initial conditions type: "
f"{obj} (type: {type(obj)}")

def get(self, storage, names):
results = super().get(storage, names)
final_results = list(self._extract_trajectories(results))
return final_results


INIT_CONDS = InitCondsLoader(
param=Option('-t', '--init-conds', multiple=True,
help=("identifier for initial conditions "
+ "(sample set or trajectory)" + HELP_MULTIPLE)),
Expand Down
8 changes: 5 additions & 3 deletions paths_cli/tests/commands/test_equilibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def print_test(output_storage, scheme, init_conds, multiplier, extra_steps):
print(isinstance(output_storage, paths.Storage))
print(scheme.__uuid__)
print(init_conds.__uuid__)
print([o.__uuid__ for o in init_conds])
print(multiplier, extra_steps)


Expand All @@ -31,8 +31,10 @@ def test_equilibrate(tps_fixture):
["setup.nc", "-o", "foo.nc"]
)
out_str = "True\n{schemeid}\n{condsid}\n1 0\n"
expected_output = out_str.format(schemeid=scheme.__uuid__,
condsid=init_conds.__uuid__)
expected_output = out_str.format(
schemeid=scheme.__uuid__,
condsid=[o.trajectory.__uuid__ for o in init_conds],
)
assert results.exit_code == 0
assert results.output == expected_output

Expand Down
5 changes: 3 additions & 2 deletions paths_cli/tests/commands/test_pathsampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def print_test(output_storage, scheme, init_conds, n_steps):
print(isinstance(output_storage, paths.Storage))
print(scheme.__uuid__)
print(init_conds.__uuid__)
print([traj.__uuid__ for traj in init_conds])
print(n_steps)

@patch('paths_cli.commands.pathsampling.pathsampling_main', print_test)
Expand All @@ -26,7 +26,8 @@ def test_pathsampling(tps_fixture):

results = runner.invoke(pathsampling, ['setup.nc', '-o', 'foo.nc',
'-n', '1000'])
expected_output = (f"True\n{scheme.__uuid__}\n{init_conds.__uuid__}"
initcondsid = [samp.trajectory.__uuid__ for samp in init_conds]
expected_output = (f"True\n{scheme.__uuid__}\n{initcondsid}"
"\n1000\n")

assert results.output == expected_output
Expand Down
35 changes: 30 additions & 5 deletions paths_cli/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,12 @@ def create_file(self, getter):
get_type, getter_style = self._parse_getter(getter)
main, other = {
'traj': (self.traj, self.other_traj),
'sset': (self.sample_set, self.other_sample_set)
'sset': (self.sample_set, self.other_sample_set),
'samp': (self.sample_set[0], self.other_sample_set[0]),
}[get_type]
if get_type == 'samp':
storage.save(main)
storage.save(other)
if get_type == 'sset':
storage.save(self.sample_set)
storage.save(self.other_sample_set)
Expand All @@ -231,20 +235,23 @@ def create_file(self, getter):

if other_tag:
storage.tags[other_tag] = other

storage.close()
return filename

@pytest.mark.parametrize("getter", [
'name-traj', 'number-traj', 'tag-final-traj', 'tag-initial-traj',
'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset'
'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset',
'name-samp', 'number-samp',
])
def test_get(self, getter):
filename = self.create_file(getter)
storage = paths.Storage(filename, mode='r')
get_type, getter_style = self._parse_getter(getter)
expected = {
'sset': self.sample_set,
'traj': self.traj
'sset': [s.trajectory for s in self.sample_set],
'traj': [self.traj],
'samp': [self.sample_set[0].trajectory],
}[get_type]
get_arg = {
'name': 'traj',
Expand Down Expand Up @@ -277,7 +284,13 @@ def test_get_none(self, num_in_file):

st = paths.Storage(filename, mode='r')
obj = INIT_CONDS.get(st, None)
assert obj == stored_things[num_in_file - 1]
expected = [
[self.traj],
[s.trajectory for s in self.sample_set],
[s.trajectory for s in self.other_sample_set],
[s.trajectory for s in self.other_sample_set],
]
assert obj == expected[num_in_file - 1]

def test_get_multiple(self):
filename = self.create_file('number-traj')
Expand All @@ -297,6 +310,18 @@ def test_cannot_guess(self):
with pytest.raises(RuntimeError):
self.PARAMETER.get(storage, None)

def test_get_bad_name(self):
filename = self._filename("bad_tag")
storage = paths.Storage(filename, 'w')
storage.save(self.traj)
storage.save(self.other_traj)
storage.tags['bad_tag'] = "foo"
storage.close()

storage = paths.Storage(filename, 'r')
with pytest.raises(RuntimeError, match="initial conditions type"):
self.PARAMETER.get(storage, "bad_tag")


class TestINIT_SNAP(ParamInstanceTest):
PARAMETER = INIT_SNAP
Expand Down