From 650c8786d1f5ec99ef4fb21b4047c0dd681b47a5 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 18 Feb 2022 15:32:56 -0500 Subject: [PATCH 01/11] add an interface to eval descriptors Fix #1393. --- deepmd/infer/deep_pot.py | 61 ++++++++++++++++----------- source/tests/infer/deeppot_descpt.txt | 6 +++ source/tests/test_deeppot_a.py | 4 ++ 3 files changed, 47 insertions(+), 24 deletions(-) create mode 100644 source/tests/infer/deeppot_descpt.txt diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index f58226c374..ef9d2a8e0e 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -81,7 +81,8 @@ def __init__( "t_force": "o_force:0", "t_virial": "o_virial:0", "t_ae": "o_atom_energy:0", - "t_av": "o_atom_virial:0" + "t_av": "o_atom_virial:0", + "t_descriptor": "o_descriptor:0", }, ) DeepEval.__init__( @@ -184,7 +185,8 @@ def eval( atomic: bool = False, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, - efield: Optional[np.ndarray] = None + efield: Optional[np.ndarray] = None, + eval_descriptor=False, ) -> Tuple[np.ndarray, ...]: """Evaluate the energy, force and virial by using this DP. @@ -216,6 +218,8 @@ def eval( efield The external field on atoms. The array should be of size nframes x natoms x 3 + eval_descriptor : bool + Eval descriptors. Returns ------- @@ -229,30 +233,29 @@ def eval( The atomic energy. Only returned when atomic == True atom_virial The atomic virial. Only returned when atomic == True + descriptor + Descriptors. Only returned when eval_descriptor == True """ # reshape coords before getting shape natoms = len(atom_types) coords = np.reshape(np.array(coords), [-1, natoms * 3]) numb_test = coords.shape[0] - if atomic: - if self.modifier_type is not None: + if self.auto_batch_size is not None: + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, *args, **kwargs) + else: + eval_func = self._eval_inner + output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield, eval_descriptor=eval_descriptor) + + if self.modifier_type is not None: + if atomic: raise RuntimeError('modifier does not support atomic modification') - if self.auto_batch_size is not None: - return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, - coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - return self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - else : - if self.auto_batch_size is not None: - e, f, v = self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, - coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - else: - e, f, v = self._eval_inner(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) - if self.modifier_type is not None: - me, mf, mv = self.dm.eval(coords, cells, atom_types) - e += me.reshape(e.shape) - f += mf.reshape(f.shape) - v += mv.reshape(v.shape) - return e, f, v + me, mf, mv = self.dm.eval(coords, cells, atom_types) + e, f, v = output[:3] + e += me.reshape(e.shape) + f += mf.reshape(f.shape) + v += mv.reshape(v.shape) + return output def _eval_inner( self, @@ -262,7 +265,8 @@ def _eval_inner( fparam=None, aparam=None, atomic=False, - efield=None + efield=None, + eval_descriptor=False, ): # standarize the shape of inputs atom_types = np.array(atom_types, dtype = int).reshape([-1]) @@ -328,6 +332,8 @@ def _eval_inner( if atomic : t_out += [self.t_ae, self.t_av] + if eval_descriptor: + t_out.append(self.t_descriptor) feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) @@ -345,9 +351,13 @@ def _eval_inner( energy = v_out[0] force = v_out[1] virial = v_out[2] + t_idx = 3 if atomic: ae = v_out[3] av = v_out[4] + t_idx += 2 + if eval_descriptor: + descriptor = v_out[t_idx] # reverse map of the outputs force = self.reverse_map(np.reshape(force, [nframes,-1,3]), imap) @@ -358,9 +368,12 @@ def _eval_inner( energy = np.reshape(energy, [nframes, 1]) force = np.reshape(force, [nframes, natoms, 3]) virial = np.reshape(virial, [nframes, 9]) + output = [energy, force, virial] if atomic: ae = np.reshape(ae, [nframes, natoms, 1]) av = np.reshape(av, [nframes, natoms, 9]) - return energy, force, virial, ae, av - else : - return energy, force, virial + output.extend([ae, av]) + if eval_descriptor: + descriptor = np.reshape(descriptor, [nframes, natoms, -1]) + output.append(descriptor) + return tuple(output) diff --git a/source/tests/infer/deeppot_descpt.txt b/source/tests/infer/deeppot_descpt.txt new file mode 100644 index 0000000000..1ea18e9a02 --- /dev/null +++ b/source/tests/infer/deeppot_descpt.txt @@ -0,0 +1,6 @@ +1.321519961572583446e+00 1.371086996432463234e+00 1.343885362068737654e+00 7.030511147133322591e-01 1.371086996432463234e+00 1.522728277961870713e+00 1.495890684109977053e+00 7.457809775555739318e-01 1.343885362068737654e+00 1.495890684109977053e+00 1.469632964943500042e+00 7.315484240279518380e-01 7.030511147133322591e-01 7.457809775555739318e-01 7.315484240279518380e-01 3.769423714838210926e-01 1.397448213242690862e+00 1.406307843425486315e+00 1.376945776785887476e+00 7.364224033531289182e-01 1.349677138676806498e+00 1.501489244191841710e+00 1.475108126068265024e+00 7.345932819566267646e-01 1.428791111325009577e+00 1.527566663819698745e+00 1.498782670338390632e+00 7.675267547771572607e-01 9.945305114958860049e-01 1.119429391438308219e+00 1.100175320069580742e+00 5.435256973435016459e-01 +1.666120192454127125e+00 1.691592960101290677e+00 1.660848078254600457e+00 9.148871223997685487e-01 1.691592960101290677e+00 1.807557128537562008e+00 1.777443773744786570e+00 9.384841989100006776e-01 1.660848078254600457e+00 1.777443773744786570e+00 1.747912040379999921e+00 9.217134916614336815e-01 9.148871223997685487e-01 9.384841989100006776e-01 9.217134916614336815e-01 5.034483776550070511e-01 1.784463156540595508e+00 1.770428626353183876e+00 1.736989815040471008e+00 9.755029855694796748e-01 1.671216271483192850e+00 1.786798962823082038e+00 1.757059913333324896e+00 9.272948493374369994e-01 1.787952366477808974e+00 1.855120330291613273e+00 1.822612777927184124e+00 9.860466542466375106e-01 1.240466553399573124e+00 1.335597814608440181e+00 1.313636073416146965e+00 6.893007932368097057e-01 +1.857190745941245336e+00 1.712653950602609498e+00 1.669033760849789605e+00 9.115932216522792952e-01 1.712653950602609498e+00 1.774801034363276520e+00 1.737803632150541455e+00 8.652896470799232853e-01 1.669033760849789605e+00 1.737803632150541455e+00 1.701914413913524937e+00 8.444842527320339798e-01 9.115932216522792952e-01 8.652896470799232853e-01 8.444842527320339798e-01 4.518552761259164718e-01 2.047343778537022985e+00 1.803503653609380475e+00 1.754050188943322208e+00 9.944616692818087911e-01 1.654654031531220149e+00 1.723583984166921823e+00 1.688052380342037084e+00 8.375386196249569037e-01 1.904925986059241572e+00 1.845391378434923402e+00 1.802137505101731207e+00 9.463545327018678677e-01 1.195956664712712669e+00 1.272969241553762787e+00 1.247794584173576471e+00 6.091661698149348769e-01 +1.729318681716640826e+00 1.620316206148274540e+00 1.578053322854602758e+00 8.384315939852660104e-01 1.620316206148274540e+00 1.676382173662613884e+00 1.639223989973451090e+00 7.989806603154945286e-01 1.578053322854602758e+00 1.639223989973451090e+00 1.603163928836319085e+00 7.787623774594047976e-01 8.384315939852660104e-01 7.989806603154945286e-01 7.787623774594047976e-01 4.083319738392942599e-01 1.895061711055783693e+00 1.704719175253225805e+00 1.657289874036961708e+00 9.129654774778642734e-01 1.563778651834606404e+00 1.623994208975480413e+00 1.588269534429824548e+00 7.718070758411798016e-01 1.784871498937309786e+00 1.743282784333909374e+00 1.700760863129526790e+00 8.714582672042692213e-01 1.128335098273277159e+00 1.191704502556441003e+00 1.166245241817210010e+00 5.588000601315349369e-01 +2.311436366095430905e+00 2.106967028437471523e+00 2.059250658547694179e+00 1.173357457272029780e+00 2.106967028437471523e+00 2.099685185993240832e+00 2.058804265227755614e+00 1.088540185250911785e+00 2.059250658547694179e+00 2.058804265227755614e+00 2.018985089841780045e+00 1.064745413335918212e+00 1.173357457272029780e+00 1.088540185250911785e+00 1.064745413335918212e+00 5.982334090189560527e-01 2.562378776048241047e+00 2.257458645273420217e+00 2.203404394497280983e+00 1.292393223640032840e+00 2.047955802464616948e+00 2.047105830421261707e+00 2.007521807924319557e+00 1.058951143245145721e+00 2.363728390448381234e+00 2.235568922662932501e+00 2.187964141514865180e+00 1.208516329724113270e+00 1.489466008127536600e+00 1.511668501789738439e+00 1.483254366306933525e+00 7.727514556258590073e-01 +2.129030587855572421e+00 1.947158041374798865e+00 1.900213202229163123e+00 1.062720944427532066e+00 1.947158041374798865e+00 1.974765866918151369e+00 1.934939878705544514e+00 9.958975115167185699e-01 1.900213202229163123e+00 1.934939878705544514e+00 1.896230134293759750e+00 9.730461889970962730e-01 1.062720944427532066e+00 9.958975115167185699e-01 9.730461889970962730e-01 5.346713914730781836e-01 2.358302986474245078e+00 2.072936997733268782e+00 2.019613016300256803e+00 1.166931775412306971e+00 1.889259211229606494e+00 1.923865183776938270e+00 1.885410865441944139e+00 9.676424835630113019e-01 2.180008917694376436e+00 2.081759028316970461e+00 2.035118746221280528e+00 1.099152505661615153e+00 1.366637640485368177e+00 1.418420572687418169e+00 1.391094495316195001e+00 7.036614101584164338e-01 diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 2ce4ba1193..a6f55bd79a 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -133,6 +133,10 @@ def test_1frame_atm(self): expected_sv = np.sum(self.expected_v.reshape([nframes, -1, 9]), axis = 1) np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places) + def test_descriptor(self): + _, _, _, descpt = self.dp.eval(self.coords, self.box, self.atype, eval_descriptor=True) + expected_descpt = np.loadtxt(str(tests_path / "infer" / "deeppot_descpt.txt")) + np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) def test_2frame_atm(self): coords2 = np.concatenate((self.coords, self.coords)) From 764b6936bd9a41f358adf9cd486ee87da24c6421 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 15:18:22 -0500 Subject: [PATCH 02/11] move eval descriptor out of eval --- deepmd/infer/deep_pot.py | 116 ++++++++++++++++++++++++++------- source/tests/test_deeppot_a.py | 2 +- 2 files changed, 93 insertions(+), 25 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index ef9d2a8e0e..f59fcdad64 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -186,7 +186,6 @@ def eval( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, - eval_descriptor=False, ) -> Tuple[np.ndarray, ...]: """Evaluate the energy, force and virial by using this DP. @@ -218,8 +217,6 @@ def eval( efield The external field on atoms. The array should be of size nframes x natoms x 3 - eval_descriptor : bool - Eval descriptors. Returns ------- @@ -233,8 +230,6 @@ def eval( The atomic energy. Only returned when atomic == True atom_virial The atomic virial. Only returned when atomic == True - descriptor - Descriptors. Only returned when eval_descriptor == True """ # reshape coords before getting shape natoms = len(atom_types) @@ -245,7 +240,7 @@ def eval_func(*args, **kwargs): return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, *args, **kwargs) else: eval_func = self._eval_inner - output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield, eval_descriptor=eval_descriptor) + output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) if self.modifier_type is not None: if atomic: @@ -257,7 +252,7 @@ def eval_func(*args, **kwargs): v += mv.reshape(v.shape) return output - def _eval_inner( + def _prepare_feed_dict( self, coords, cells, @@ -266,7 +261,6 @@ def _eval_inner( aparam=None, atomic=False, efield=None, - eval_descriptor=False, ): # standarize the shape of inputs atom_types = np.array(atom_types, dtype = int).reshape([-1]) @@ -326,15 +320,6 @@ def _eval_inner( feed_dict_test = {} feed_dict_test[self.t_natoms] = natoms_vec feed_dict_test[self.t_type ] = np.tile(atom_types, [nframes, 1]).reshape([-1]) - t_out = [self.t_energy, - self.t_force, - self.t_virial] - if atomic : - t_out += [self.t_ae, - self.t_av] - if eval_descriptor: - t_out.append(self.t_descriptor) - feed_dict_test[self.t_coord] = np.reshape(coords, [-1]) feed_dict_test[self.t_box ] = np.reshape(cells , [-1]) if self.has_efield: @@ -347,17 +332,36 @@ def _eval_inner( feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) if self.has_aparam: feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) + return feed_dict_test + + def _eval_inner( + self, + coords, + cells, + atom_types, + fparam=None, + aparam=None, + atomic=False, + efield=None, + ): + natoms = atom_types.size + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + t_out = [self.t_energy, + self.t_force, + self.t_virial] + if atomic : + t_out += [self.t_ae, + self.t_av] + v_out = run_sess(self.sess, t_out, feed_dict = feed_dict_test) energy = v_out[0] force = v_out[1] virial = v_out[2] - t_idx = 3 if atomic: ae = v_out[3] av = v_out[4] - t_idx += 2 - if eval_descriptor: - descriptor = v_out[t_idx] # reverse map of the outputs force = self.reverse_map(np.reshape(force, [nframes,-1,3]), imap) @@ -373,7 +377,71 @@ def _eval_inner( ae = np.reshape(ae, [nframes, natoms, 1]) av = np.reshape(av, [nframes, natoms, 9]) output.extend([ae, av]) - if eval_descriptor: - descriptor = np.reshape(descriptor, [nframes, natoms, -1]) - output.append(descriptor) return tuple(output) + + def eval_descriptor(self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + ) -> np.array: + """Evaluate descriptors by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + efield + The external field on atoms. + The array should be of size nframes x natoms x 3 + + Returns + ------- + descriptor + Descriptors. + """ + natoms = len(atom_types) + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + numb_test = coords.shape[0] + if self.auto_batch_size is not None: + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all(self._eval_descriptor_inner, numb_test, natoms, *args, **kwargs) + else: + eval_func = self._eval_descriptor_inner + output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + + def _eval_descriptor_inner(self, + coords: np.ndarray, + cells: np.ndarray, + atom_types: List[int], + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + efield: Optional[np.ndarray] = None, + ) -> np.array: + natoms = atom_types.size + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + descriptor, = run_sess(self.sess, [self.t_descriptor], feed_dict = feed_dict_test) + return np.reshape(descriptor, [nframes, natoms, -1]) \ No newline at end of file diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index a6f55bd79a..44e41346dc 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -134,7 +134,7 @@ def test_1frame_atm(self): np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places) def test_descriptor(self): - _, _, _, descpt = self.dp.eval(self.coords, self.box, self.atype, eval_descriptor=True) + _, _, _, descpt = self.dp.eval_descriptor(self.coords, self.box, self.atype) expected_descpt = np.loadtxt(str(tests_path / "infer" / "deeppot_descpt.txt")) np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) From 5a309431de54615286d0c09c1dd4e60296927724 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 15:25:11 -0500 Subject: [PATCH 03/11] bugfix and remove unnecessary changes --- deepmd/infer/deep_pot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index f59fcdad64..fe4d71c4d1 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -260,7 +260,7 @@ def _prepare_feed_dict( fparam=None, aparam=None, atomic=False, - efield=None, + efield=None ): # standarize the shape of inputs atom_types = np.array(atom_types, dtype = int).reshape([-1]) @@ -342,7 +342,7 @@ def _eval_inner( fparam=None, aparam=None, atomic=False, - efield=None, + efield=None ): natoms = atom_types.size coords = np.reshape(np.array(coords), [-1, natoms * 3]) @@ -372,12 +372,12 @@ def _eval_inner( energy = np.reshape(energy, [nframes, 1]) force = np.reshape(force, [nframes, natoms, 3]) virial = np.reshape(virial, [nframes, 9]) - output = [energy, force, virial] if atomic: ae = np.reshape(ae, [nframes, natoms, 1]) av = np.reshape(av, [nframes, natoms, 9]) - output.extend([ae, av]) - return tuple(output) + return energy, force, virial, ae, av + else : + return energy, force, virial def eval_descriptor(self, coords: np.ndarray, @@ -429,7 +429,7 @@ def eval_func(*args, **kwargs): return self.auto_batch_size.execute_all(self._eval_descriptor_inner, numb_test, natoms, *args, **kwargs) else: eval_func = self._eval_descriptor_inner - output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + return eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) def _eval_descriptor_inner(self, coords: np.ndarray, From c99f6f2bc8d1f8b8215f27d8ac95591033598033 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 15:26:01 -0500 Subject: [PATCH 04/11] fix test --- source/tests/test_deeppot_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 44e41346dc..eeb42ae113 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -134,7 +134,7 @@ def test_1frame_atm(self): np.testing.assert_almost_equal(vv.ravel(), expected_sv.ravel(), default_places) def test_descriptor(self): - _, _, _, descpt = self.dp.eval_descriptor(self.coords, self.box, self.atype) + descpt = self.dp.eval_descriptor(self.coords, self.box, self.atype) expected_descpt = np.loadtxt(str(tests_path / "infer" / "deeppot_descpt.txt")) np.testing.assert_almost_equal(descpt.ravel(), expected_descpt.ravel()) From a599e6e2b951a5ef7be627416a8db925701f1d72 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 16:29:50 -0500 Subject: [PATCH 05/11] merge duplicated codes --- deepmd/infer/deep_pot.py | 61 ++++++++++++++++++++++------------------ 1 file changed, 34 insertions(+), 27 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index fe4d71c4d1..5ee3d06d39 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -1,5 +1,5 @@ import logging -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable import numpy as np from deepmd.common import make_default_mesh @@ -176,6 +176,32 @@ def get_dim_fparam(self) -> int: def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this DP.""" return self.daparam + + def _eval_func(self, inner_func: Callable) -> Callable: + """Wrapper method with auto batch size. + + Parameters + ---------- + inner_func : Callable + the method to be wrapped + + Returns + ------- + Callable + the wrapper + """ + if self.auto_batch_size is not None: + def eval_func(*args, **kwargs): + return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, *args, **kwargs) + else: + eval_func = self._eval_inner + return eval_func + + def _get_natoms_and_nframes(self, coords: np.ndarray, atom_types: List[int]) -> Tuple[int, int]: + natoms = len(atom_types) + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes def eval( self, @@ -232,15 +258,8 @@ def eval( The atomic virial. Only returned when atomic == True """ # reshape coords before getting shape - natoms = len(atom_types) - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - numb_test = coords.shape[0] - if self.auto_batch_size is not None: - def eval_func(*args, **kwargs): - return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, *args, **kwargs) - else: - eval_func = self._eval_inner - output = eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) + output = self._eval_func(self._eval_inner)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) if self.modifier_type is not None: if atomic: @@ -263,10 +282,9 @@ def _prepare_feed_dict( efield=None ): # standarize the shape of inputs + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) atom_types = np.array(atom_types, dtype = int).reshape([-1]) - natoms = atom_types.size coords = np.reshape(np.array(coords), [-1, natoms * 3]) - nframes = coords.shape[0] if cells is None: pbc = False # make cells to work around the requirement of pbc @@ -344,9 +362,7 @@ def _eval_inner( atomic=False, efield=None ): - natoms = atom_types.size - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - nframes = coords.shape[0] + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) t_out = [self.t_energy, self.t_force, @@ -421,15 +437,8 @@ def eval_descriptor(self, descriptor Descriptors. """ - natoms = len(atom_types) - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - numb_test = coords.shape[0] - if self.auto_batch_size is not None: - def eval_func(*args, **kwargs): - return self.auto_batch_size.execute_all(self._eval_descriptor_inner, numb_test, natoms, *args, **kwargs) - else: - eval_func = self._eval_descriptor_inner - return eval_func(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) + return self._eval_func(self._eval_descriptor_inner)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) def _eval_descriptor_inner(self, coords: np.ndarray, @@ -439,9 +448,7 @@ def _eval_descriptor_inner(self, aparam: Optional[np.ndarray] = None, efield: Optional[np.ndarray] = None, ) -> np.array: - natoms = atom_types.size - coords = np.reshape(np.array(coords), [-1, natoms * 3]) - nframes = coords.shape[0] + natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) descriptor, = run_sess(self.sess, [self.t_descriptor], feed_dict = feed_dict_test) return np.reshape(descriptor, [nframes, natoms, -1]) \ No newline at end of file From a8fba9406b7275c4538e409361de5497152319fe Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 16:54:50 -0500 Subject: [PATCH 06/11] bugfix --- deepmd/infer/deep_pot.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 5ee3d06d39..c87646f2dc 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -177,13 +177,17 @@ def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this DP.""" return self.daparam - def _eval_func(self, inner_func: Callable) -> Callable: + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: """Wrapper method with auto batch size. Parameters ---------- inner_func : Callable the method to be wrapped + numb_test: int + number of tests + natoms : int + number of atoms Returns ------- @@ -259,7 +263,7 @@ def eval( """ # reshape coords before getting shape natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) - output = self._eval_func(self._eval_inner)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + output = self._eval_func(self._eval_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) if self.modifier_type is not None: if atomic: @@ -438,7 +442,7 @@ def eval_descriptor(self, Descriptors. """ natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) - return self._eval_func(self._eval_descriptor_inner)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + return self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) def _eval_descriptor_inner(self, coords: np.ndarray, From e332de3038bf616bee5cbf466f86f0ca2c7b0b3b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 17:10:33 -0500 Subject: [PATCH 07/11] imap --- deepmd/infer/deep_pot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index c87646f2dc..7c4066cdf4 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -354,7 +354,7 @@ def _prepare_feed_dict( feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1]) if self.has_aparam: feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1]) - return feed_dict_test + return feed_dict_test, imap def _eval_inner( self, @@ -367,7 +367,7 @@ def _eval_inner( efield=None ): natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) - feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) t_out = [self.t_energy, self.t_force, self.t_virial] @@ -453,6 +453,6 @@ def _eval_descriptor_inner(self, efield: Optional[np.ndarray] = None, ) -> np.array: natoms, nframes = self._get_natoms_and_nframes(coords, atom_types) - feed_dict_test = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) + feed_dict_test, imap = self._prepare_feed_dict(coords, cells, atom_types, fparam, aparam, efield) descriptor, = run_sess(self.sess, [self.t_descriptor], feed_dict = feed_dict_test) - return np.reshape(descriptor, [nframes, natoms, -1]) \ No newline at end of file + return self.reverse_map(np.reshape(descriptor, [nframes, natoms, -1]), imap) From 11d26e58810dd86a24c3db9c5595a84b059c5411 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 17:32:55 -0500 Subject: [PATCH 08/11] Update deep_pot.py --- deepmd/infer/deep_pot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 7c4066cdf4..293923c6e4 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -442,7 +442,7 @@ def eval_descriptor(self, Descriptors. """ natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) - return self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, atomic = atomic, efield = efield) + return self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) def _eval_descriptor_inner(self, coords: np.ndarray, From 30628805f2d37c37d9e3bb9b1f7c3a0799444669 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 17:47:43 -0500 Subject: [PATCH 09/11] do not return tuple --- deepmd/infer/deep_pot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 293923c6e4..3c85c01e51 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -442,7 +442,8 @@ def eval_descriptor(self, Descriptors. """ natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) - return self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) + descriptor, = self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) + return descriptor def _eval_descriptor_inner(self, coords: np.ndarray, From 569c61aa93d21cbf8cfcbe919e757bf5b8d68b36 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 19 Feb 2022 19:43:23 -0500 Subject: [PATCH 10/11] fix typo --- deepmd/infer/deep_pot.py | 6 +++--- source/tests/infer/deeppot_descpt.txt | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 3c85c01e51..bd8ae1491b 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -196,9 +196,9 @@ def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Calla """ if self.auto_batch_size is not None: def eval_func(*args, **kwargs): - return self.auto_batch_size.execute_all(self._eval_inner, numb_test, natoms, *args, **kwargs) + return self.auto_batch_size.execute_all(inner_func, numb_test, natoms, *args, **kwargs) else: - eval_func = self._eval_inner + eval_func = inner_func return eval_func def _get_natoms_and_nframes(self, coords: np.ndarray, atom_types: List[int]) -> Tuple[int, int]: @@ -442,7 +442,7 @@ def eval_descriptor(self, Descriptors. """ natoms, numb_test = self._get_natoms_and_nframes(coords, atom_types) - descriptor, = self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) + descriptor = self._eval_func(self._eval_descriptor_inner, numb_test, natoms)(coords, cells, atom_types, fparam = fparam, aparam = aparam, efield = efield) return descriptor def _eval_descriptor_inner(self, diff --git a/source/tests/infer/deeppot_descpt.txt b/source/tests/infer/deeppot_descpt.txt index 1ea18e9a02..d757dc7d08 100644 --- a/source/tests/infer/deeppot_descpt.txt +++ b/source/tests/infer/deeppot_descpt.txt @@ -1,6 +1,6 @@ 1.321519961572583446e+00 1.371086996432463234e+00 1.343885362068737654e+00 7.030511147133322591e-01 1.371086996432463234e+00 1.522728277961870713e+00 1.495890684109977053e+00 7.457809775555739318e-01 1.343885362068737654e+00 1.495890684109977053e+00 1.469632964943500042e+00 7.315484240279518380e-01 7.030511147133322591e-01 7.457809775555739318e-01 7.315484240279518380e-01 3.769423714838210926e-01 1.397448213242690862e+00 1.406307843425486315e+00 1.376945776785887476e+00 7.364224033531289182e-01 1.349677138676806498e+00 1.501489244191841710e+00 1.475108126068265024e+00 7.345932819566267646e-01 1.428791111325009577e+00 1.527566663819698745e+00 1.498782670338390632e+00 7.675267547771572607e-01 9.945305114958860049e-01 1.119429391438308219e+00 1.100175320069580742e+00 5.435256973435016459e-01 -1.666120192454127125e+00 1.691592960101290677e+00 1.660848078254600457e+00 9.148871223997685487e-01 1.691592960101290677e+00 1.807557128537562008e+00 1.777443773744786570e+00 9.384841989100006776e-01 1.660848078254600457e+00 1.777443773744786570e+00 1.747912040379999921e+00 9.217134916614336815e-01 9.148871223997685487e-01 9.384841989100006776e-01 9.217134916614336815e-01 5.034483776550070511e-01 1.784463156540595508e+00 1.770428626353183876e+00 1.736989815040471008e+00 9.755029855694796748e-01 1.671216271483192850e+00 1.786798962823082038e+00 1.757059913333324896e+00 9.272948493374369994e-01 1.787952366477808974e+00 1.855120330291613273e+00 1.822612777927184124e+00 9.860466542466375106e-01 1.240466553399573124e+00 1.335597814608440181e+00 1.313636073416146965e+00 6.893007932368097057e-01 1.857190745941245336e+00 1.712653950602609498e+00 1.669033760849789605e+00 9.115932216522792952e-01 1.712653950602609498e+00 1.774801034363276520e+00 1.737803632150541455e+00 8.652896470799232853e-01 1.669033760849789605e+00 1.737803632150541455e+00 1.701914413913524937e+00 8.444842527320339798e-01 9.115932216522792952e-01 8.652896470799232853e-01 8.444842527320339798e-01 4.518552761259164718e-01 2.047343778537022985e+00 1.803503653609380475e+00 1.754050188943322208e+00 9.944616692818087911e-01 1.654654031531220149e+00 1.723583984166921823e+00 1.688052380342037084e+00 8.375386196249569037e-01 1.904925986059241572e+00 1.845391378434923402e+00 1.802137505101731207e+00 9.463545327018678677e-01 1.195956664712712669e+00 1.272969241553762787e+00 1.247794584173576471e+00 6.091661698149348769e-01 1.729318681716640826e+00 1.620316206148274540e+00 1.578053322854602758e+00 8.384315939852660104e-01 1.620316206148274540e+00 1.676382173662613884e+00 1.639223989973451090e+00 7.989806603154945286e-01 1.578053322854602758e+00 1.639223989973451090e+00 1.603163928836319085e+00 7.787623774594047976e-01 8.384315939852660104e-01 7.989806603154945286e-01 7.787623774594047976e-01 4.083319738392942599e-01 1.895061711055783693e+00 1.704719175253225805e+00 1.657289874036961708e+00 9.129654774778642734e-01 1.563778651834606404e+00 1.623994208975480413e+00 1.588269534429824548e+00 7.718070758411798016e-01 1.784871498937309786e+00 1.743282784333909374e+00 1.700760863129526790e+00 8.714582672042692213e-01 1.128335098273277159e+00 1.191704502556441003e+00 1.166245241817210010e+00 5.588000601315349369e-01 +1.666120192454127125e+00 1.691592960101290677e+00 1.660848078254600457e+00 9.148871223997685487e-01 1.691592960101290677e+00 1.807557128537562008e+00 1.777443773744786570e+00 9.384841989100006776e-01 1.660848078254600457e+00 1.777443773744786570e+00 1.747912040379999921e+00 9.217134916614336815e-01 9.148871223997685487e-01 9.384841989100006776e-01 9.217134916614336815e-01 5.034483776550070511e-01 1.784463156540595508e+00 1.770428626353183876e+00 1.736989815040471008e+00 9.755029855694796748e-01 1.671216271483192850e+00 1.786798962823082038e+00 1.757059913333324896e+00 9.272948493374369994e-01 1.787952366477808974e+00 1.855120330291613273e+00 1.822612777927184124e+00 9.860466542466375106e-01 1.240466553399573124e+00 1.335597814608440181e+00 1.313636073416146965e+00 6.893007932368097057e-01 2.311436366095430905e+00 2.106967028437471523e+00 2.059250658547694179e+00 1.173357457272029780e+00 2.106967028437471523e+00 2.099685185993240832e+00 2.058804265227755614e+00 1.088540185250911785e+00 2.059250658547694179e+00 2.058804265227755614e+00 2.018985089841780045e+00 1.064745413335918212e+00 1.173357457272029780e+00 1.088540185250911785e+00 1.064745413335918212e+00 5.982334090189560527e-01 2.562378776048241047e+00 2.257458645273420217e+00 2.203404394497280983e+00 1.292393223640032840e+00 2.047955802464616948e+00 2.047105830421261707e+00 2.007521807924319557e+00 1.058951143245145721e+00 2.363728390448381234e+00 2.235568922662932501e+00 2.187964141514865180e+00 1.208516329724113270e+00 1.489466008127536600e+00 1.511668501789738439e+00 1.483254366306933525e+00 7.727514556258590073e-01 2.129030587855572421e+00 1.947158041374798865e+00 1.900213202229163123e+00 1.062720944427532066e+00 1.947158041374798865e+00 1.974765866918151369e+00 1.934939878705544514e+00 9.958975115167185699e-01 1.900213202229163123e+00 1.934939878705544514e+00 1.896230134293759750e+00 9.730461889970962730e-01 1.062720944427532066e+00 9.958975115167185699e-01 9.730461889970962730e-01 5.346713914730781836e-01 2.358302986474245078e+00 2.072936997733268782e+00 2.019613016300256803e+00 1.166931775412306971e+00 1.889259211229606494e+00 1.923865183776938270e+00 1.885410865441944139e+00 9.676424835630113019e-01 2.180008917694376436e+00 2.081759028316970461e+00 2.035118746221280528e+00 1.099152505661615153e+00 1.366637640485368177e+00 1.418420572687418169e+00 1.391094495316195001e+00 7.036614101584164338e-01 From a55f10bb1b0dc37580bbddfbd0f51b23f3d0cd39 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 20 Feb 2022 03:42:46 -0500 Subject: [PATCH 11/11] make the code more readable --- deepmd/infer/deep_pot.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index bd8ae1491b..131999727a 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -269,10 +269,12 @@ def eval( if atomic: raise RuntimeError('modifier does not support atomic modification') me, mf, mv = self.dm.eval(coords, cells, atom_types) + output = list(output) # tuple to list e, f, v = output[:3] - e += me.reshape(e.shape) - f += mf.reshape(f.shape) - v += mv.reshape(v.shape) + output[0] += me.reshape(e.shape) + output[1] += mf.reshape(f.shape) + output[2] += mv.reshape(v.shape) + output = tuple(output) return output def _prepare_feed_dict(