Skip to content

Commit

Permalink
validate traj before conf selection (#147)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 3, 2023
1 parent 0563691 commit 1d21697
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
22 changes: 22 additions & 0 deletions dpgen2/op/select_confs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
OPIO,
Artifact,
BigParameter,
FatalError,
OPIOSign,
)

Expand Down Expand Up @@ -79,6 +80,7 @@ def execute(

trajs = ip["trajs"]
model_devis = ip["model_devis"]
trajs, model_devis = SelectConfs.validate_trajs(trajs, model_devis)

confs, report = conf_selector.select(
trajs,
Expand All @@ -92,3 +94,23 @@ def execute(
"confs": confs,
}
)

@staticmethod
def validate_trajs(
trajs,
model_devis,
):
ntrajs = len(trajs)
if ntrajs != len(model_devis):
raise FatalError(
"length of trajs list is not equal to the " "model_devis list"
)
rett = []
retm = []
for tt, mm in zip(trajs, model_devis):
if (tt is None and mm is not None) or (tt is not None and mm is None):
raise FatalError("trajs frame is {tt} while model_devis frame is {mm}")
elif tt is not None and mm is not None:
rett.append(tt)
retm.append(mm)
return rett, retm
23 changes: 23 additions & 0 deletions tests/test_select_confs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
OP,
OPIO,
Artifact,
FatalError,
OPIOSign,
PythonOPTemplate,
)
Expand Down Expand Up @@ -124,3 +125,25 @@ def test(self):
self.assertTrue(confs[1].is_file())
self.assertTrue(confs[0].read_text(), "conf of conf.0")
self.assertTrue(confs[1].read_text(), "conf of conf.1")

def test_validate_trajs(self):
trajs = ["foo", "bar", None, "tar"]
model_devis = ["zar", "par", None, "mar"]
trajs, model_devis = SelectConfs.validate_trajs(trajs, model_devis)
self.assertEqual(trajs, ["foo", "bar", "tar"])
self.assertEqual(model_devis, ["zar", "par", "mar"])

trajs = ["foo", "bar", None, "tar"]
model_devis = ["zar", "par", None]
with self.assertRaises(FatalError) as context:
trajs, model_devis = SelectConfs.validate_trajs(trajs, model_devis)

trajs = ["foo", "bar"]
model_devis = ["zar", None]
with self.assertRaises(FatalError) as context:
trajs, model_devis = SelectConfs.validate_trajs(trajs, model_devis)

trajs = ["foo", None]
model_devis = ["zar", "par"]
with self.assertRaises(FatalError) as context:
trajs, model_devis = SelectConfs.validate_trajs(trajs, model_devis)

0 comments on commit 1d21697

Please sign in to comment.