diff --git a/dmff/common/nblist.py b/dmff/common/nblist.py index c57b481be..1b06631ec 100644 --- a/dmff/common/nblist.py +++ b/dmff/common/nblist.py @@ -44,7 +44,7 @@ def allocate(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None): self.nblist = self.neighborlist_fn.allocate(positions) else: self.update(positions, box) - return self.nblist + return self.pairs def update(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None): """ A function to update a neighbor list given a new set of positions and a previously allocated neighbor list. @@ -59,7 +59,7 @@ def update(self, positions: jnp.ndarray, box: Optional[jnp.ndarray] = None): self.nblist = self.nblist.update(positions) else: self.nblist = self.nblist.update(positions, box) - return self.nblist + return self.pairs @property def pairs(self): diff --git a/dmff/generators/classical.py b/dmff/generators/classical.py index b4584ba9a..e88c909d2 100644 --- a/dmff/generators/classical.py +++ b/dmff/generators/classical.py @@ -97,6 +97,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): map_param = np.array(map_param, dtype=int) bforce = HarmonicBondJaxForce(map_atom1, map_atom2, map_param) + self._force_latest = bforce def potential_fn(positions, box, pairs, params): return bforce.get_energy(positions, box, pairs, @@ -163,6 +164,7 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): aforce = HarmonicAngleJaxForce(map_atom1, map_atom2, map_atom3, map_param) + self._force_latest = aforce def potential_fn(positions, box, pairs, params): return aforce.get_energy(positions, box, pairs, @@ -384,6 +386,8 @@ def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): jnp.array(map_impr_param[p], dtype=int), p) for p in range(1, self.max_pred_impr + 1) ] + self._props_latest = props + self._imprs_latest = imprs def potential_fn(positions, box, pairs, params): prop_sum = sum([ @@ -822,8 +826,8 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, map_lj = jnp.array(maps["sigma"]) - ifType = len(self.fftree.get_attribs("LennardJonesForce/Atom", - "type")) != 0 + ifType = len([i for i in self.fftree.get_attribs("LennardJonesForce/Atom", + "type") if i is not None]) != 0 if ifType: atom_labels = self.fftree.get_attribs("LennardJonesForce/Atom", "type")