diff --git a/paths_cli/parameters.py b/paths_cli/parameters.py index e8fbc58..574b17d 100644 --- a/paths_cli/parameters.py +++ b/paths_cli/parameters.py @@ -18,7 +18,29 @@ store='schemes', ) -INIT_CONDS = OPSStorageLoadMultiple( +class InitCondsLoader(OPSStorageLoadMultiple): + def _extract_trajectories(self, obj): + import openpathsampling as paths + if isinstance(obj, paths.SampleSet): + yield from (s.trajectory for s in obj) + elif isinstance(obj, paths.Sample): + yield obj.trajectory + elif isinstance(obj, paths.Trajectory): + yield obj + elif isinstance(obj, list): + for o in obj: + yield from self._extract_trajectories(o) + else: + raise RuntimeError("Unknown initial conditions type: " + f"{obj} (type: {type(obj)}") + + def get(self, storage, names): + results = super().get(storage, names) + final_results = list(self._extract_trajectories(results)) + return final_results + + +INIT_CONDS = InitCondsLoader( param=Option('-t', '--init-conds', multiple=True, help=("identifier for initial conditions " + "(sample set or trajectory)" + HELP_MULTIPLE)), diff --git a/paths_cli/tests/commands/test_equilibrate.py b/paths_cli/tests/commands/test_equilibrate.py index c57a8b8..01c3689 100644 --- a/paths_cli/tests/commands/test_equilibrate.py +++ b/paths_cli/tests/commands/test_equilibrate.py @@ -11,7 +11,7 @@ def print_test(output_storage, scheme, init_conds, multiplier, extra_steps): print(isinstance(output_storage, paths.Storage)) print(scheme.__uuid__) - print(init_conds.__uuid__) + print([o.__uuid__ for o in init_conds]) print(multiplier, extra_steps) @@ -31,8 +31,10 @@ def test_equilibrate(tps_fixture): ["setup.nc", "-o", "foo.nc"] ) out_str = "True\n{schemeid}\n{condsid}\n1 0\n" - expected_output = out_str.format(schemeid=scheme.__uuid__, - condsid=init_conds.__uuid__) + expected_output = out_str.format( + schemeid=scheme.__uuid__, + condsid=[o.trajectory.__uuid__ for o in init_conds], + ) assert results.exit_code == 0 assert results.output == expected_output diff --git a/paths_cli/tests/commands/test_pathsampling.py b/paths_cli/tests/commands/test_pathsampling.py index dc5034a..08d7db3 100644 --- a/paths_cli/tests/commands/test_pathsampling.py +++ b/paths_cli/tests/commands/test_pathsampling.py @@ -11,7 +11,7 @@ def print_test(output_storage, scheme, init_conds, n_steps): print(isinstance(output_storage, paths.Storage)) print(scheme.__uuid__) - print(init_conds.__uuid__) + print([traj.__uuid__ for traj in init_conds]) print(n_steps) @patch('paths_cli.commands.pathsampling.pathsampling_main', print_test) @@ -26,7 +26,8 @@ def test_pathsampling(tps_fixture): results = runner.invoke(pathsampling, ['setup.nc', '-o', 'foo.nc', '-n', '1000']) - expected_output = (f"True\n{scheme.__uuid__}\n{init_conds.__uuid__}" + initcondsid = [samp.trajectory.__uuid__ for samp in init_conds] + expected_output = (f"True\n{scheme.__uuid__}\n{initcondsid}" "\n1000\n") assert results.output == expected_output diff --git a/paths_cli/tests/test_parameters.py b/paths_cli/tests/test_parameters.py index b748fbe..e4cc9f1 100644 --- a/paths_cli/tests/test_parameters.py +++ b/paths_cli/tests/test_parameters.py @@ -214,8 +214,12 @@ def create_file(self, getter): get_type, getter_style = self._parse_getter(getter) main, other = { 'traj': (self.traj, self.other_traj), - 'sset': (self.sample_set, self.other_sample_set) + 'sset': (self.sample_set, self.other_sample_set), + 'samp': (self.sample_set[0], self.other_sample_set[0]), }[get_type] + if get_type == 'samp': + storage.save(main) + storage.save(other) if get_type == 'sset': storage.save(self.sample_set) storage.save(self.other_sample_set) @@ -231,20 +235,23 @@ def create_file(self, getter): if other_tag: storage.tags[other_tag] = other + storage.close() return filename @pytest.mark.parametrize("getter", [ 'name-traj', 'number-traj', 'tag-final-traj', 'tag-initial-traj', - 'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset' + 'name-sset', 'number-sset', 'tag-final-sset', 'tag-initial-sset', + 'name-samp', 'number-samp', ]) def test_get(self, getter): filename = self.create_file(getter) storage = paths.Storage(filename, mode='r') get_type, getter_style = self._parse_getter(getter) expected = { - 'sset': self.sample_set, - 'traj': self.traj + 'sset': [s.trajectory for s in self.sample_set], + 'traj': [self.traj], + 'samp': [self.sample_set[0].trajectory], }[get_type] get_arg = { 'name': 'traj', @@ -277,7 +284,13 @@ def test_get_none(self, num_in_file): st = paths.Storage(filename, mode='r') obj = INIT_CONDS.get(st, None) - assert obj == stored_things[num_in_file - 1] + expected = [ + [self.traj], + [s.trajectory for s in self.sample_set], + [s.trajectory for s in self.other_sample_set], + [s.trajectory for s in self.other_sample_set], + ] + assert obj == expected[num_in_file - 1] def test_get_multiple(self): filename = self.create_file('number-traj') @@ -297,6 +310,18 @@ def test_cannot_guess(self): with pytest.raises(RuntimeError): self.PARAMETER.get(storage, None) + def test_get_bad_name(self): + filename = self._filename("bad_tag") + storage = paths.Storage(filename, 'w') + storage.save(self.traj) + storage.save(self.other_traj) + storage.tags['bad_tag'] = "foo" + storage.close() + + storage = paths.Storage(filename, 'r') + with pytest.raises(RuntimeError, match="initial conditions type"): + self.PARAMETER.get(storage, "bad_tag") + class TestINIT_SNAP(ParamInstanceTest): PARAMETER = INIT_SNAP