-
Notifications
You must be signed in to change notification settings - Fork 487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add an interface to eval descriptors #1483
Changes from 10 commits
650c878
764b693
5a30943
c99f6f2
a599e6e
a8fba94
e332de3
11d26e5
3062880
569c61a
a55f10b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
import logging | ||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union | ||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable | ||
|
||
import numpy as np | ||
from deepmd.common import make_default_mesh | ||
|
@@ -81,7 +81,8 @@ def __init__( | |
"t_force": "o_force:0", | ||
"t_virial": "o_virial:0", | ||
"t_ae": "o_atom_energy:0", | ||
"t_av": "o_atom_virial:0" | ||
"t_av": "o_atom_virial:0", | ||
"t_descriptor": "o_descriptor:0", | ||
}, | ||
) | ||
DeepEval.__init__( | ||
|
@@ -175,6 +176,36 @@ def get_dim_fparam(self) -> int: | |
def get_dim_aparam(self) -> int: | ||
"""Get the number (dimension) of atomic parameters of this DP.""" | ||
return self.daparam | ||
|
||
def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: | ||
"""Wrapper method with auto batch size. | ||
|
||
Parameters | ||
---------- | ||
inner_func : Callable | ||
the method to be wrapped | ||
numb_test: int | ||
number of tests | ||
natoms : int | ||
number of atoms | ||
|
||
Returns | ||
------- | ||
Callable | ||
the wrapper | ||
""" | ||
if self.auto_batch_size is not None: | ||
def eval_func(*args, **kwargs): | ||
return self.auto_batch_size.execute_all(inner_func, numb_test, natoms, *args, **kwargs) | ||
else: | ||
eval_func = inner_func | ||
return eval_func | ||
|
||
def _get_natoms_and_nframes(self, coords: np.ndarray, atom_types: List[int]) -> Tuple[int, int]: | ||
natoms = len(atom_types) | ||
coords = np.reshape(np.array(coords), [-1, natoms * 3]) | ||
nframes = coords.shape[0] | ||
return natoms, nframes | ||
|
||
def eval( | ||
self, | ||
|
@@ -184,7 +215,7 @@ def eval( | |
atomic: bool = False, | ||
fparam: Optional[np.ndarray] = None, | ||
aparam: Optional[np.ndarray] = None, | ||
efield: Optional[np.ndarray] = None | ||
efield: Optional[np.ndarray] = None, | ||
) -> Tuple[np.ndarray, ...]: | ||
"""Evaluate the energy, force and virial by using this DP. | ||
|
||
|
@@ -231,30 +262,20 @@ def eval( | |
The atomic virial. Only returned when atomic == True | ||
""" | ||
# reshape coords before getting shape | ||
natoms = len(atom_types) | ||
coords = np.reshape(np.array(coords), [-1, natoms * 3]) | ||
numb_test = coords.shape[0] | ||
if atomic: | ||
if self.modifier_type is not None: | ||
raise RuntimeError('modifier does not support atomic modification') | ||
if self.auto_batch_size is not None: | ||
return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, | ||
coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) | ||
return self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) | ||
else : | ||
if self.auto_batch_size is not None: | ||
e, f, v = self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, | ||
coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) | ||
else: | ||
e, f, v = self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) | ||
if self.modifier_type is not None: | ||
me, mf, mv = self.dm.eval(coords, cells, atom_types) | ||
e += me.reshape(e.shape) | ||
f += mf.reshape(f.shape) | ||
v += mv.reshape(v.shape) | ||
return e, f, v | ||
natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) | ||
output = self._eval_func(self._eval_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) | ||
|
||
def _eval_inner( | ||
if self.modifier_type is not None: | ||
if atomic: | ||
raise RuntimeError('modifier does not support atomic modification') | ||
me, mf, mv = self.dm.eval(coords, cells, atom_types) | ||
e, f, v = output[:3] | ||
e += me.reshape(e.shape) | ||
f += mf.reshape(f.shape) | ||
v += mv.reshape(v.shape) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think so. Here is a minimal example: >>> import numpy as np
>>> output = (np.ones(3), np.ones(3), np.ones(3), np.ones(3), np.ones(3))
>>> e,f,v = output[:3]
>>> e += 1
>>> f += 2
>>> v += 3
>>> output
(array([2., 2., 2.]), array([3., 3., 3.]), array([4., 4., 4.]), array([1., 1., 1.]), array([1., 1., 1.])) See https://stackoverflow.com/a/35910888/9567349. In Python, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. by In [1]: a = [1,2,3]
In [2]: a[2]+=1
In [3]: a
Out[3]: [1, 2, 4]
In [4]: e,f,v=a[:3]
In [5]: e+=1
In [6]: f+=1
In [7]: v+=1
In [8]: a
Out[8]: [1, 2, 4]
In [9]: e, f, v
Out[9]: (2, 3, 5) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can reproduce your example. output[0] += me.reshape(e.shape)
... |
||
return output | ||
|
||
def _prepare_feed_dict( | ||
self, | ||
coords, | ||
cells, | ||
|
@@ -265,10 +286,9 @@ def _eval_inner( | |
efield=None | ||
): | ||
# standarize the shape of inputs | ||
natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) | ||
atom_types = np.array(atom_types, dtype = int).reshape([-1]) | ||
natoms = atom_types.size | ||
coords = np.reshape(np.array(coords), [-1, natoms * 3]) | ||
nframes = coords.shape[0] | ||
if cells is None: | ||
pbc = False | ||
# make cells to work around the requirement of pbc | ||
|
@@ -322,13 +342,6 @@ def _eval_inner( | |
feed_dict_test = {} | ||
feed_dict_test[self.t_natoms] = natoms_vec | ||
feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes, 1]).reshape([-1]) | ||
t_out = [self.t_energy, | ||
self.t_force, | ||
self.t_virial] | ||
if atomic : | ||
t_out += [self.t_ae, | ||
self.t_av] | ||
|
||
feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) | ||
feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) | ||
if self.has_efield: | ||
|
@@ -341,6 +354,27 @@ def _eval_inner( | |
feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) | ||
if self.has_aparam: | ||
feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) | ||
return feed_dict_test, imap | ||
|
||
def _eval_inner( | ||
self, | ||
coords, | ||
cells, | ||
atom_types, | ||
fparam=None, | ||
aparam=None, | ||
atomic=False, | ||
efield=None | ||
): | ||
natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) | ||
feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) | ||
t_out = [self.t_energy, | ||
self.t_force, | ||
self.t_virial] | ||
if atomic : | ||
t_out += [self.t_ae, | ||
self.t_av] | ||
|
||
v_out = run_sess(self.sess, t_out, feed_dict = feed_dict_test) | ||
energy = v_out[0] | ||
force = v_out[1] | ||
|
@@ -364,3 +398,62 @@ def _eval_inner( | |
return energy, force, virial, ae, av | ||
else : | ||
return energy, force, virial | ||
|
||
def eval_descriptor(self, | ||
coords: np.ndarray, | ||
cells: np.ndarray, | ||
atom_types: List[int], | ||
fparam: Optional[np.ndarray] = None, | ||
aparam: Optional[np.ndarray] = None, | ||
efield: Optional[np.ndarray] = None, | ||
) -> np.array: | ||
"""Evaluate descriptors by using this DP. | ||
|
||
Parameters | ||
---------- | ||
coords | ||
The coordinates of atoms. | ||
The array should be of size nframes x natoms x 3 | ||
cells | ||
The cell of the region. | ||
If None then non-PBC is assumed, otherwise using PBC. | ||
The array should be of size nframes x 9 | ||
atom_types | ||
The atom types | ||
The list should contain natoms ints | ||
fparam | ||
The frame parameter. | ||
The array can be of size : | ||
- nframes x dim_fparam. | ||
- dim_fparam. Then all frames are assumed to be provided with the same fparam. | ||
aparam | ||
The atomic parameter | ||
The array can be of size : | ||
- nframes x natoms x dim_aparam. | ||
- natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. | ||
- dim_aparam. Then all frames and atoms are provided with the same aparam. | ||
efield | ||
The external field on atoms. | ||
The array should be of size nframes x natoms x 3 | ||
|
||
Returns | ||
------- | ||
descriptor | ||
Descriptors. | ||
""" | ||
natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) | ||
descriptor = self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) | ||
return descriptor | ||
|
||
def _eval_descriptor_inner(self, | ||
coords: np.ndarray, | ||
cells: np.ndarray, | ||
atom_types: List[int], | ||
fparam: Optional[np.ndarray] = None, | ||
aparam: Optional[np.ndarray] = None, | ||
efield: Optional[np.ndarray] = None, | ||
) -> np.array: | ||
natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) | ||
feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) | ||
descriptor, = run_sess(self.sess, [self.t_descriptor], feed_dict = feed_dict_test) | ||
return self.reverse_map(np.reshape(descriptor, [nframes, natoms, -1]), imap) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
1.321519961572583446e+00 1.371086996432463234e+00 1.343885362068737654e+00 7.030511147133322591e-01 1.371086996432463234e+00 1.522728277961870713e+00 1.495890684109977053e+00 7.457809775555739318e-01 1.343885362068737654e+00 1.495890684109977053e+00 1.469632964943500042e+00 7.315484240279518380e-01 7.030511147133322591e-01 7.457809775555739318e-01 7.315484240279518380e-01 3.769423714838210926e-01 1.397448213242690862e+00 1.406307843425486315e+00 1.376945776785887476e+00 7.364224033531289182e-01 1.349677138676806498e+00 1.501489244191841710e+00 1.475108126068265024e+00 7.345932819566267646e-01 1.428791111325009577e+00 1.527566663819698745e+00 1.498782670338390632e+00 7.675267547771572607e-01 9.945305114958860049e-01 1.119429391438308219e+00 1.100175320069580742e+00 5.435256973435016459e-01 | ||
1.857190745941245336e+00 1.712653950602609498e+00 1.669033760849789605e+00 9.115932216522792952e-01 1.712653950602609498e+00 1.774801034363276520e+00 1.737803632150541455e+00 8.652896470799232853e-01 1.669033760849789605e+00 1.737803632150541455e+00 1.701914413913524937e+00 8.444842527320339798e-01 9.115932216522792952e-01 8.652896470799232853e-01 8.444842527320339798e-01 4.518552761259164718e-01 2.047343778537022985e+00 1.803503653609380475e+00 1.754050188943322208e+00 9.944616692818087911e-01 1.654654031531220149e+00 1.723583984166921823e+00 1.688052380342037084e+00 8.375386196249569037e-01 1.904925986059241572e+00 1.845391378434923402e+00 1.802137505101731207e+00 9.463545327018678677e-01 1.195956664712712669e+00 1.272969241553762787e+00 1.247794584173576471e+00 6.091661698149348769e-01 | ||
1.729318681716640826e+00 1.620316206148274540e+00 1.578053322854602758e+00 8.384315939852660104e-01 1.620316206148274540e+00 1.676382173662613884e+00 1.639223989973451090e+00 7.989806603154945286e-01 1.578053322854602758e+00 1.639223989973451090e+00 1.603163928836319085e+00 7.787623774594047976e-01 8.384315939852660104e-01 7.989806603154945286e-01 7.787623774594047976e-01 4.083319738392942599e-01 1.895061711055783693e+00 1.704719175253225805e+00 1.657289874036961708e+00 9.129654774778642734e-01 1.563778651834606404e+00 1.623994208975480413e+00 1.588269534429824548e+00 7.718070758411798016e-01 1.784871498937309786e+00 1.743282784333909374e+00 1.700760863129526790e+00 8.714582672042692213e-01 1.128335098273277159e+00 1.191704502556441003e+00 1.166245241817210010e+00 5.588000601315349369e-01 | ||
1.666120192454127125e+00 1.691592960101290677e+00 1.660848078254600457e+00 9.148871223997685487e-01 1.691592960101290677e+00 1.807557128537562008e+00 1.777443773744786570e+00 9.384841989100006776e-01 1.660848078254600457e+00 1.777443773744786570e+00 1.747912040379999921e+00 9.217134916614336815e-01 9.148871223997685487e-01 9.384841989100006776e-01 9.217134916614336815e-01 5.034483776550070511e-01 1.784463156540595508e+00 1.770428626353183876e+00 1.736989815040471008e+00 9.755029855694796748e-01 1.671216271483192850e+00 1.786798962823082038e+00 1.757059913333324896e+00 9.272948493374369994e-01 1.787952366477808974e+00 1.855120330291613273e+00 1.822612777927184124e+00 9.860466542466375106e-01 1.240466553399573124e+00 1.335597814608440181e+00 1.313636073416146965e+00 6.893007932368097057e-01 | ||
2.311436366095430905e+00 2.106967028437471523e+00 2.059250658547694179e+00 1.173357457272029780e+00 2.106967028437471523e+00 2.099685185993240832e+00 2.058804265227755614e+00 1.088540185250911785e+00 2.059250658547694179e+00 2.058804265227755614e+00 2.018985089841780045e+00 1.064745413335918212e+00 1.173357457272029780e+00 1.088540185250911785e+00 1.064745413335918212e+00 5.982334090189560527e-01 2.562378776048241047e+00 2.257458645273420217e+00 2.203404394497280983e+00 1.292393223640032840e+00 2.047955802464616948e+00 2.047105830421261707e+00 2.007521807924319557e+00 1.058951143245145721e+00 2.363728390448381234e+00 2.235568922662932501e+00 2.187964141514865180e+00 1.208516329724113270e+00 1.489466008127536600e+00 1.511668501789738439e+00 1.483254366306933525e+00 7.727514556258590073e-01 | ||
2.129030587855572421e+00 1.947158041374798865e+00 1.900213202229163123e+00 1.062720944427532066e+00 1.947158041374798865e+00 1.974765866918151369e+00 1.934939878705544514e+00 9.958975115167185699e-01 1.900213202229163123e+00 1.934939878705544514e+00 1.896230134293759750e+00 9.730461889970962730e-01 1.062720944427532066e+00 9.958975115167185699e-01 9.730461889970962730e-01 5.346713914730781836e-01 2.358302986474245078e+00 2.072936997733268782e+00 2.019613016300256803e+00 1.166931775412306971e+00 1.889259211229606494e+00 1.923865183776938270e+00 1.885410865441944139e+00 9.676424835630113019e-01 2.180008917694376436e+00 2.081759028316970461e+00 2.035118746221280528e+00 1.099152505661615153e+00 1.366637640485368177e+00 1.418420572687418169e+00 1.391094495316195001e+00 7.036614101584164338e-01 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any reason why
auto_batch_size
is removed?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not removed. Instead, it's moved into a wrapper
eval_func
to make the code more clear.