diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index f58226c374..131999727a 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable import numpy as np from deepmd.common import make_default_mesh @@ -81,7 +81,8 @@ def __init__( "t_force": "o_force:0", "t_virial": "o_virial:0", "t_ae": "o_atom_energy:0", - "t_av": "o_atom_virial:0" + "t_av": "o_atom_virial:0", + "t_descriptor": "o_descriptor:0", }, ) DeepEval.__init__( @@ -175,6 +176,36 @@ def get_dim_fparam(self) -> int: def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this DP.""" return self.daparam + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size. + + Parameters + ---------- + inner_func : Callable + the method to be wrapped + numb_test: int + number of tests + natoms : int + number of atoms + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all(inner_func, numb_test, natoms, *args, **kwargs) + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes(self, coords: np.ndarray, atom_types: List[int]) -> Tuple[int, int]: + natoms = len(atom_types) + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes def eval( self, @@ -184,7 +215,7 @@ def eval( atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, - efield: Optional[np.ndarray] = None + efield: Optional[np.ndarray] = None, ) -> Tuple[np.ndarray, ...]: """Evaluate the energy, force and virial by using this DP. @@ -231,30 +262,22 @@ def eval( The atomic virial. Only returned when atomic == True """ # reshape coords before getting shape - natoms = len(atom_types) - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - numb_test = coords.shape[0] - if atomic: - if self.modifier_type is not None: - raise RuntimeError('modifier does not support atomic modification') - if self.auto_batch_size is not None: - return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, - coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - return self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - else : - if self.auto_batch_size is not None: - e, f, v = self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, - coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - else: - e, f, v = self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - if self.modifier_type is not None: - me, mf, mv = self.dm.eval(coords, cells, atom_types) - e += me.reshape(e.shape) - f += mf.reshape(f.shape) - v += mv.reshape(v.shape) - return e, f, v + natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) + output = self._eval_func(self._eval_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - def _eval_inner( + if self.modifier_type is not None: + if atomic: + raise RuntimeError('modifier does not support atomic modification') + me, mf, mv = self.dm.eval(coords, cells, atom_types) + output = list(output) # tuple to list + e, f, v = output[:3] + output[0] += me.reshape(e.shape) + output[1] += mf.reshape(f.shape) + output[2] += mv.reshape(v.shape) + output = tuple(output) + return output + + def _prepare_feed_dict( self, coords, cells, @@ -265,10 +288,9 @@ def _eval_inner( efield=None ): # standarize the shape of inputs + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) atom_types = np.array(atom_types, dtype = int).reshape([-1]) - natoms = atom_types.size coords = np.reshape(np.array(coords), [-1, natoms * 3]) - nframes = coords.shape[0] if cells is None: pbc = False # make cells to work around the requirement of pbc @@ -322,13 +344,6 @@ def _eval_inner( feed_dict_test = {} feed_dict_test[self.t_natoms] = natoms_vec feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes, 1]).reshape([-1]) - t_out = [self.t_energy, - self.t_force, - self.t_virial] - if atomic : - t_out += [self.t_ae, - self.t_av] - feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) if self.has_efield: @@ -341,6 +356,27 @@ def _eval_inner( feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) if self.has_aparam: feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) + return feed_dict_test, imap + + def _eval_inner( + self, + coords, + cells, + atom_types, + fparam=None, + aparam=None, + atomic=False, + efield=None + ): + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) + feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + t_out = [self.t_energy, + self.t_force, + self.t_virial] + if atomic : + t_out += [self.t_ae, + self.t_av] + v_out = run_sess(self.sess, t_out, feed_dict = feed_dict_test) energy = v_out[0] force = v_out[1] @@ -364,3 +400,62 @@ def _eval_inner( return energy, force, virial, ae, av else : return energy, force, virial + + def eval_descriptor(self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + ) -> np.array: + """Evaluate descriptors by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + efield + The external field on atoms. + The array should be of size nframes x natoms x 3 + + Returns + ------- + descriptor + Descriptors. + """ + natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) + descriptor = self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) + return descriptor + + def _eval_descriptor_inner(self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + ) -> np.array: + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) + feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + descriptor, = run_sess(self.sess, [self.t_descriptor], feed_dict = feed_dict_test) + return self.reverse_map(np.reshape(descriptor, [nframes, natoms, -1]), imap) diff --git a/source/tests/infer/deeppot_descpt.txt b/source/tests/infer/deeppot_descpt.txt new file mode 100644 index 0000000000..d757dc7d08 --- /dev/null +++ b/source/tests/infer/deeppot_descpt.txt @@ -0,0 +1,6 @@ +1.321519961572583446e+00 1.371086996432463234e+00 1.343885362068737654e+00 7.030511147133322591e-01 1.371086996432463234e+00 1.522728277961870713e+00 1.495890684109977053e+00 7.457809775555739318e-01 1.343885362068737654e+00 1.495890684109977053e+00 1.469632964943500042e+00 7.315484240279518380e-01 7.030511147133322591e-01 7.457809775555739318e-01 7.315484240279518380e-01 3.769423714838210926e-01 1.397448213242690862e+00 1.406307843425486315e+00 1.376945776785887476e+00 7.364224033531289182e-01 1.349677138676806498e+00 1.501489244191841710e+00 1.475108126068265024e+00 7.345932819566267646e-01 1.428791111325009577e+00 1.527566663819698745e+00 1.498782670338390632e+00 7.675267547771572607e-01 9.945305114958860049e-01 1.119429391438308219e+00 1.100175320069580742e+00 5.435256973435016459e-01 +1.857190745941245336e+00 1.712653950602609498e+00 1.669033760849789605e+00 9.115932216522792952e-01 1.712653950602609498e+00 1.774801034363276520e+00 1.737803632150541455e+00 8.652896470799232853e-01 1.669033760849789605e+00 1.737803632150541455e+00 1.701914413913524937e+00 8.444842527320339798e-01 9.115932216522792952e-01 8.652896470799232853e-01 8.444842527320339798e-01 4.518552761259164718e-01 2.047343778537022985e+00 1.803503653609380475e+00 1.754050188943322208e+00 9.944616692818087911e-01 1.654654031531220149e+00 1.723583984166921823e+00 1.688052380342037084e+00 8.375386196249569037e-01 1.904925986059241572e+00 1.845391378434923402e+00 1.802137505101731207e+00 9.463545327018678677e-01 1.195956664712712669e+00 1.272969241553762787e+00 1.247794584173576471e+00 6.091661698149348769e-01 +1.729318681716640826e+00 1.620316206148274540e+00 1.578053322854602758e+00 8.384315939852660104e-01 1.620316206148274540e+00 1.676382173662613884e+00 1.639223989973451090e+00 7.989806603154945286e-01 1.578053322854602758e+00 1.639223989973451090e+00 1.603163928836319085e+00 7.787623774594047976e-01 8.384315939852660104e-01 7.989806603154945286e-01 7.787623774594047976e-01 4.083319738392942599e-01 1.895061711055783693e+00 1.704719175253225805e+00 1.657289874036961708e+00 9.129654774778642734e-01 1.563778651834606404e+00 1.623994208975480413e+00 1.588269534429824548e+00 7.718070758411798016e-01 1.784871498937309786e+00 1.743282784333909374e+00 1.700760863129526790e+00 8.714582672042692213e-01 1.128335098273277159e+00 1.191704502556441003e+00 1.166245241817210010e+00 5.588000601315349369e-01 +1.666120192454127125e+00 1.691592960101290677e+00 1.660848078254600457e+00 9.148871223997685487e-01 1.691592960101290677e+00 1.807557128537562008e+00 1.777443773744786570e+00 9.384841989100006776e-01 1.660848078254600457e+00 1.777443773744786570e+00 1.747912040379999921e+00 9.217134916614336815e-01 9.148871223997685487e-01 9.384841989100006776e-01 9.217134916614336815e-01 5.034483776550070511e-01 1.784463156540595508e+00 1.770428626353183876e+00 1.736989815040471008e+00 9.755029855694796748e-01 1.671216271483192850e+00 1.786798962823082038e+00 1.757059913333324896e+00 9.272948493374369994e-01 1.787952366477808974e+00 1.855120330291613273e+00 1.822612777927184124e+00 9.860466542466375106e-01 1.240466553399573124e+00 1.335597814608440181e+00 1.313636073416146965e+00 6.893007932368097057e-01 +2.311436366095430905e+00 2.106967028437471523e+00 2.059250658547694179e+00 1.173357457272029780e+00 2.106967028437471523e+00 2.099685185993240832e+00 2.058804265227755614e+00 1.088540185250911785e+00 2.059250658547694179e+00 2.058804265227755614e+00 2.018985089841780045e+00 1.064745413335918212e+00 1.173357457272029780e+00 1.088540185250911785e+00 1.064745413335918212e+00 5.982334090189560527e-01 2.562378776048241047e+00 2.257458645273420217e+00 2.203404394497280983e+00 1.292393223640032840e+00 2.047955802464616948e+00 2.047105830421261707e+00 2.007521807924319557e+00 1.058951143245145721e+00 2.363728390448381234e+00 2.235568922662932501e+00 2.187964141514865180e+00 1.208516329724113270e+00 1.489466008127536600e+00 1.511668501789738439e+00 1.483254366306933525e+00 7.727514556258590073e-01 +2.129030587855572421e+00 1.947158041374798865e+00 1.900213202229163123e+00 1.062720944427532066e+00 1.947158041374798865e+00 1.974765866918151369e+00 1.934939878705544514e+00 9.958975115167185699e-01 1.900213202229163123e+00 1.934939878705544514e+00 1.896230134293759750e+00 9.730461889970962730e-01 1.062720944427532066e+00 9.958975115167185699e-01 9.730461889970962730e-01 5.346713914730781836e-01 2.358302986474245078e+00 2.072936997733268782e+00 2.019613016300256803e+00 1.166931775412306971e+00 1.889259211229606494e+00 1.923865183776938270e+00 1.885410865441944139e+00 9.676424835630113019e-01 2.180008917694376436e+00 2.081759028316970461e+00 2.035118746221280528e+00 1.099152505661615153e+00 1.366637640485368177e+00 1.418420572687418169e+00 1.391094495316195001e+00 7.036614101584164338e-01 diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 2ce4ba1193..eeb42ae113 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -133,6 +133,10 @@ def test_1frame_atm(self): expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis = 1) np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places) + def test_descriptor(self): + descpt = self.dp.eval_descriptor(self.coords, self.box, self.atype) + expected_descpt = np.loadtxt(str(tests_path / "infer" / "deeppot_descpt.txt")) + np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) def test_2frame_atm(self): coords2 = np.concatenate((self.coords, self.coords))