forked from dptech-corp/dpgen2
-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refact configuration selection (#109)
This PR refactor the configuration selection of dpgen2. Some key changes are - `TrajRender` to render the trajectory and model deviation of exploration output. - `ExplorationReportTrustLevels` provides the rules of configuration selection. It 1. analyzes the model deviation 2. tells which frames are selected 3. supports summary printing - `ExplorationReportTrustLevels` does not depend on the format of trajectory and model deviation - `ConfSelectorLammpsFrames` -> `ConfSelectorFrames` uses render and report to select configurations from trajectories. Its implementation also does not depend on the format of trajectory and model deviation. Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
- Loading branch information
1 parent
62f0a0d
commit 5465850
Showing
18 changed files
with
632 additions
and
395 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
from .traj_render import ( | ||
TrajRender, | ||
) | ||
from .traj_render_lammps import ( | ||
TrajRenderLammps, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import dpdata | ||
import numpy as np | ||
from typing import ( | ||
List, | ||
Optional, | ||
Tuple, | ||
Union, | ||
) | ||
from abc import ABC, abstractmethod | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from dpgen2.exploration.selector import ConfFilters | ||
|
||
|
||
class TrajRender(ABC): | ||
@abstractmethod | ||
def get_model_devi( | ||
self, | ||
files : List[Path], | ||
) -> Tuple[List[np.ndarray], Union[List[np.ndarray], None]]: | ||
r"""Get model deviations from recording files. | ||
Parameters | ||
---------- | ||
files: List[Path] | ||
The paths to the model deviation recording files | ||
Returns | ||
------- | ||
model_devis: Tuple[List[np.array], Union[List[np.array],None]] | ||
A tuple. model_devis[0] is the force model deviations, | ||
model_devis[1] is the virial model deviations. | ||
The model_devis[1] can be None. | ||
If not None, model_devis[i] is List[np.array], where np.array is a | ||
one-dimensional array. | ||
The first dimension of model_devis[i] is the trajectory | ||
(same size as len(files)), while the second dimension is the frame. | ||
""" | ||
pass | ||
|
||
|
||
@abstractmethod | ||
def get_confs( | ||
self, | ||
traj: List[Path], | ||
id_selected: List[List[int]], | ||
type_map: Optional[List[str]] = None, | ||
conf_filters: Optional["ConfFilters"] = None, | ||
) -> dpdata.MultiSystems : | ||
r"""Get configurations from trajectory by selection. | ||
Parameters | ||
---------- | ||
traj: List[Path] | ||
Trajectory files | ||
id_selected: List[List[int]] | ||
The selected frames. id_selected[ii][jj] is the jj-th selected frame | ||
from the ii-th trajectory. id_selected[ii] may be an empty list. | ||
type_map: List[str] | ||
The type map. | ||
Returns | ||
------- | ||
ms: dpdata.MultiSystems | ||
The configurations in dpdata.MultiSystems format | ||
""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import numpy as np | ||
import dpdata | ||
from typing import ( | ||
List, | ||
Optional, | ||
Tuple, | ||
Union, | ||
) | ||
from .traj_render import TrajRender | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING | ||
if TYPE_CHECKING: | ||
from dpgen2.exploration.selector import ConfFilters | ||
|
||
|
||
class TrajRenderLammps(TrajRender): | ||
def __init__( | ||
self, | ||
nopbc : bool=False, | ||
): | ||
self.nopbc = nopbc | ||
|
||
def get_model_devi( | ||
self, | ||
files : List[Path], | ||
) -> Tuple[List[np.ndarray], Union[List[np.ndarray], None]]: | ||
nframes = len(files) | ||
mdfs = [] | ||
mdvs = [] | ||
for ii in range(nframes): | ||
mdf, mdv = self._load_one_model_devi(files[ii]) | ||
mdfs.append(mdf) | ||
mdvs.append(mdv) | ||
return mdfs, mdvs | ||
|
||
def _load_one_model_devi(self, fname): | ||
dd = np.loadtxt(fname) | ||
return dd[:,4], dd[:,1] | ||
|
||
def get_confs( | ||
self, | ||
trajs: List[Path], | ||
id_selected: List[List[int]], | ||
type_map: Optional[List[str]] = None, | ||
conf_filters: Optional["ConfFilters"] = None, | ||
) -> dpdata.MultiSystems : | ||
del conf_filters # by far does not support conf filters | ||
ntraj = len(trajs) | ||
traj_fmt = 'lammps/dump' | ||
ms = dpdata.MultiSystems(type_map=type_map) | ||
for ii in range(ntraj): | ||
if len(id_selected[ii]) > 0: | ||
ss = dpdata.System(trajs[ii], fmt=traj_fmt, type_map=type_map) | ||
ss.nopbc = self.nopbc | ||
ss = ss.sub_system(id_selected[ii]) | ||
ms.append(ss) | ||
return ms |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,75 @@ | ||
import numpy as np | ||
from abc import ABC, abstractmethod | ||
from typing import Tuple | ||
from typing import ( | ||
Tuple, List, Optional, | ||
) | ||
|
||
class ExplorationReport(ABC): | ||
def __init__(self): | ||
@abstractmethod | ||
def clear(self): | ||
r"""Clear the report""" | ||
pass | ||
|
||
@abstractmethod | ||
def failed_ratio ( | ||
self, | ||
tag = None, | ||
) -> float : | ||
def record( | ||
self, | ||
md_f : List[np.ndarray], | ||
md_v : Optional[List[np.ndarray]] = None, | ||
): | ||
r"""Record the model deviations of the trajectories | ||
Parameters | ||
---------- | ||
mdf : List[np.ndarray] | ||
The force model deviations. mdf[ii][jj] is the force model deviation | ||
of the jj-th frame of the ii-th trajectory. | ||
mdv : Optional[List[np.ndarray]] | ||
The virial model deviations. mdv[ii][jj] is the virial model deviation | ||
of the jj-th frame of the ii-th trajectory. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def accurate_ratio ( | ||
self, | ||
tag = None, | ||
) -> float : | ||
def converged(self) -> bool : | ||
r"""If the exploration is converged""" | ||
pass | ||
|
||
def no_candidate(self) -> bool: | ||
r"""If no candidate configuration is found""" | ||
return all([ len(ii) == 0 for ii in self.get_candidate_ids()]) | ||
|
||
@abstractmethod | ||
def candidate_ratio ( | ||
self, | ||
tag = None, | ||
) -> float : | ||
def get_candidate_ids( | ||
self, | ||
max_nframes : Optional[int] = None, | ||
) -> List[List[int]]: | ||
r"""Get indexes of candidate configurations | ||
Parameters | ||
---------- | ||
max_nframes int | ||
The maximal number of frames of candidates. | ||
Returns | ||
------- | ||
idx: List[List[int]] | ||
The frame indices of candidate configurations. | ||
idx[ii][jj] is the frame index of the jj-th candidate of the | ||
ii-th trajectory. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def print_header(self) -> str: | ||
r"""Print the header of report""" | ||
pass | ||
|
||
@abstractmethod | ||
def print( | ||
self, | ||
stage_idx : int, | ||
idx_in_stage : int, | ||
iter_idx : int, | ||
) -> str: | ||
r"""Print the report""" | ||
pass |
Oops, something went wrong.