Skip to content

Commit

Permalink
Merge pull request #187 from wanghan-iapcm/add-abs-method
Browse files Browse the repository at this point in the history
refactorize ExplorationTaskGroup.
  • Loading branch information
zjgemi committed Feb 2, 2024
2 parents 0f843bf + 74873cb commit c0f1973
Show file tree
Hide file tree
Showing 11 changed files with 141 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ConfSelector,
)
from dpgen2.exploration.task import (
BaseExplorationTaskGroup,
ExplorationStage,
ExplorationTaskGroup,
)
Expand Down Expand Up @@ -67,7 +68,7 @@ def plan_next_iteration(
self,
report: Optional[ExplorationReport] = None,
trajs: Optional[List[Path]] = None,
) -> Tuple[bool, Optional[ExplorationTaskGroup], Optional[ConfSelector]]:
) -> Tuple[bool, Optional[BaseExplorationTaskGroup], Optional[ConfSelector]]:
if self.complete():
raise FatalError("Cannot plan because the stage has completed.")
if report is None:
Expand Down
3 changes: 3 additions & 0 deletions dpgen2/exploration/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@
)
from .task import (
ExplorationTask,
)
from .task_group import (
BaseExplorationTaskGroup,
ExplorationTaskGroup,
)
2 changes: 2 additions & 0 deletions dpgen2/exploration/task/conf_sampling_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .task import (
ExplorationTask,
)
from .task_group import (
ExplorationTaskGroup,
)

Expand Down
5 changes: 2 additions & 3 deletions dpgen2/exploration/task/customized_lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -147,7 +146,7 @@ def set_lmp(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "CustomizedLmpTemplateTaskGroup":
if not self.conf_set:
raise RuntimeError("confs are not set")
if not self.lmp_set:
Expand All @@ -166,7 +165,7 @@ def make_task(
def _make_customized_task_group(
self,
conf,
) -> ExplorationTaskGroup:
) -> "CustomizedLmpTemplateTaskGroup":
with tempfile.TemporaryDirectory() as tmpdir:
with set_directory(Path(tmpdir)):
Path(self.input_lmp_conf_name).write_text(conf)
Expand Down
3 changes: 1 addition & 2 deletions dpgen2/exploration/task/lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -60,7 +59,7 @@ def set_lmp(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "LmpTemplateTaskGroup":
if not self.conf_set:
raise RuntimeError("confs are not set")
if not self.lmp_set:
Expand Down
3 changes: 1 addition & 2 deletions dpgen2/exploration/task/npt_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -76,7 +75,7 @@ def set_md(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "NPTTaskGroup":
"""
Make the LAMMPS task group.
Expand Down
9 changes: 6 additions & 3 deletions dpgen2/exploration/task/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

from .task import (
ExplorationTask,
)
from .task_group import (
BaseExplorationTaskGroup,
ExplorationTaskGroup,
)

Expand Down Expand Up @@ -52,20 +55,20 @@ def add_task_group(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> BaseExplorationTaskGroup:
"""
Make the LAMMPS task group.
Returns
-------
task_grp: ExplorationTaskGroup
task_grp: BaseExplorationTaskGroup
The returned lammps task group. The number of tasks is equal to
the summation of task groups defined by all the exploration groups
added to the stage.
"""

lmp_task_grp = ExplorationTaskGroup()
lmp_task_grp = BaseExplorationTaskGroup()
for ii in self.explor_groups:
# lmp_task_grp.add_group(ii.make_task())
lmp_task_grp += ii.make_task()
Expand Down
91 changes: 0 additions & 91 deletions dpgen2/exploration/task/task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import os
from abc import (
ABC,
abstractmethod,
)
from collections.abc import (
Sequence,
)
Expand Down Expand Up @@ -58,90 +54,3 @@ def files(self) -> Dict:
The dict storing all files for the task. The file name is a key of the dict, and the file content is the corresponding value.
"""
return self._files


class ExplorationTaskGroup(Sequence):
"""A group of exploration tasks. Implemented as a `list` of `ExplorationTask`."""

def __init__(self):
super().__init__()
self.clear()

def __getitem__(self, ii: int) -> ExplorationTask:
"""Get the `ii`th task"""
return self.task_list[ii]

def __len__(self) -> int:
"""Get the number of tasks in the group"""
return len(self.task_list)

def clear(self) -> None:
self._task_list = []

@property
def task_list(self) -> List[ExplorationTask]:
"""Get the `list` of `ExplorationTask`"""
return self._task_list

def add_task(self, task: ExplorationTask):
"""Add one task to the group."""
self.task_list.append(task)
return self

def add_group(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
# see https://www.python.org/dev/peps/pep-0484/#forward-references for forward references
self._task_list = self._task_list + group._task_list
return self

def __add__(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
return self.add_group(group)


class FooTask(ExplorationTask):
def __init__(
self,
conf_name="conf.lmp",
conf_cont="",
inpu_name="in.lammps",
inpu_cont="",
):
super().__init__()
self._files = {
conf_name: conf_cont,
inpu_name: inpu_cont,
}


class FooTaskGroup(ExplorationTaskGroup):
def __init__(self, numb_task):
super().__init__()
# TODO: confirm the following is correct
self.tlist = ExplorationTaskGroup()
for ii in range(numb_task):
self.tlist.add_task(
FooTask(
f"conf.{ii}",
f"this is conf.{ii}",
f"input.{ii}",
f"this is input.{ii}",
)
)

@property
def task_list(self):
return self.tlist


if __name__ == "__main__":
grp = FooTaskGroup(3)
for ii in grp:
fcs = ii.files()
print(fcs)
113 changes: 113 additions & 0 deletions dpgen2/exploration/task/task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from abc import (
ABC,
abstractmethod,
)
from collections.abc import (
Sequence,
)
from typing import (
Dict,
List,
Tuple,
)

from .task import (
ExplorationTask,
)


class BaseExplorationTaskGroup(Sequence):
"""A group of exploration tasks. Implemented as a `list` of `ExplorationTask`."""

def __init__(self):
super().__init__()
self.clear()

def __getitem__(self, ii: int) -> ExplorationTask:
"""Get the `ii`th task"""
return self.task_list[ii]

def __len__(self) -> int:
"""Get the number of tasks in the group"""
return len(self.task_list)

def clear(self) -> None:
self._task_list = []

@property
def task_list(self) -> List[ExplorationTask]:
"""Get the `list` of `ExplorationTask`"""
return self._task_list

def add_task(self, task: ExplorationTask):
"""Add one task to the group."""
self.task_list.append(task)
return self

def add_group(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
# see https://www.python.org/dev/peps/pep-0484/#forward-references for forward references
self._task_list = self._task_list + group._task_list
return self

def __add__(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
return self.add_group(group)


class ExplorationTaskGroup(ABC, BaseExplorationTaskGroup):
def __init__(self):
super().__init__()

@abstractmethod
def make_task(self) -> "ExplorationTaskGroup":
"""Make the task group."""
pass


class FooTask(ExplorationTask):
def __init__(
self,
conf_name="conf.lmp",
conf_cont="",
inpu_name="in.lammps",
inpu_cont="",
):
super().__init__()
self._files = {
conf_name: conf_cont,
inpu_name: inpu_cont,
}


class FooTaskGroup(BaseExplorationTaskGroup):
def __init__(self, numb_task):
super().__init__()
# TODO: confirm the following is correct
self.tlist = BaseExplorationTaskGroup()
for ii in range(numb_task):
self.tlist.add_task(
FooTask(
f"conf.{ii}",
f"this is conf.{ii}",
f"input.{ii}",
f"this is input.{ii}",
)
)

@property
def task_list(self):
return self.tlist


if __name__ == "__main__":
grp = FooTaskGroup(3)
for ii in grp:
fcs = ii.files()
print(fcs)
9 changes: 9 additions & 0 deletions tests/mocked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedExplorationTaskGroup1(ExplorationTaskGroup):
def __init__(self):
Expand All @@ -801,6 +804,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedExplorationTaskGroup2(ExplorationTaskGroup):
def __init__(self):
Expand All @@ -813,6 +819,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedStage(ExplorationStage):
def make_task(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_prep_run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
train_task_pattern,
)
from dpgen2.exploration.task import (
BaseExplorationTaskGroup,
ExplorationTask,
ExplorationTaskGroup,
)
from dpgen2.op.prep_lmp import (
PrepLmp,
Expand All @@ -90,7 +90,7 @@


def make_task_group_list(ngrp, ntask_per_grp):
tgrp = ExplorationTaskGroup()
tgrp = BaseExplorationTaskGroup()
for ii in range(ngrp):
for jj in range(ntask_per_grp):
tt = ExplorationTask()
Expand Down

0 comments on commit c0f1973

Please sign in to comment.