From 2956d89b1ba364666e20fd1dc1ffbcf6fb0510eb Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Wed, 6 Jul 2022 19:36:26 +0800 Subject: [PATCH 1/2] fix(classical): avoid creating empty dict in impr parameters --- dmff/api.py | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/dmff/api.py b/dmff/api.py index 7c2e11a7d..1750d1f26 100644 --- a/dmff/api.py +++ b/dmff/api.py @@ -1823,6 +1823,16 @@ def render(self, filename): def getParameters(self): return self.paramtree + def updateParameters(self, paramtree): + def update_iter(node, ref): + for key in ref: + if isinstance(ref[key], dict): + update_iter(node[key], ref[key]) + else: + node[key] = ref[key] + + update_iter(self.paramtree, paramtree) + class HarmonicBondJaxGenerator: def __init__(self, ff): @@ -1996,6 +2006,9 @@ def extract(self): prop_phase[f"{npred}"]) self.paramtree[self.name]["prop_k"][f"{npred}"] = jnp.array( prop_k[f"{npred}"]) + if self.max_pred_prop == 0: + del self.paramtree[self.name]["prop_phase"] + del self.paramtree[self.name]["prop_k"] # impropers impr_phase = defaultdict(list) @@ -2021,6 +2034,9 @@ def extract(self): impr_phase[f"{npred}"]) self.paramtree[self.name]["impr_k"][f"{npred}"] = jnp.array( impr_k[f"{npred}"]) + if self.max_pred_impr == 0: + del self.paramtree[self.name]["impr_phase"] + del self.paramtree[self.name]["impr_k"] def overwrite(self): propers = self.fftree.get_nodes("PeriodicTorsionForce/Proper") @@ -2028,7 +2044,7 @@ def overwrite(self): prop_data = [{} for _ in propers] impr_data = [{} for _ in impropers] # make propers - for periodicity in range(1, self.max_pred_prop): + for periodicity in range(1, self.max_pred_prop+1): nterms = len( self.paramtree[self.name][f"prop_phase"][f"{periodicity}"]) for nitem in range(nterms): @@ -2040,10 +2056,11 @@ def overwrite(self): order = self.meta[f"prop_order"][f"{periodicity}"][nitem] prop_data[nodeidx][f"phase{order}"] = phase prop_data[nodeidx][f"k{order}"] = k - self.fftree.set_node("PeriodicTorsionForce/Proper", prop_data) + if "prop_phase" in self.paramtree[self.name]: + self.fftree.set_node("PeriodicTorsionForce/Proper", prop_data) # make impropers - for periodicity in range(1, self.max_pred_impr): + for periodicity in range(1, self.max_pred_impr+1): nterms = len( self.paramtree[self.name][f"impr_phase"][f"{periodicity}"]) for nitem in range(nterms): @@ -2055,7 +2072,8 @@ def overwrite(self): order = self.meta[f"impr_order"][f"{periodicity}"][nitem] impr_data[nodeidx][f"phase{order}"] = phase impr_data[nodeidx][f"k{order}"] = k - self.fftree.set_node("PeriodicTorsionForce/Improper", impr_data) + if "impr_phase" in self.paramtree[self.name]: + self.fftree.set_node("PeriodicTorsionForce/Improper", impr_data) def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args): proper_matcher = TypeMatcher(self.fftree, From 4055f117d076d7b81ce81f1fe822e301f919be71 Mon Sep 17 00:00:00 2001 From: WangXinyan940 Date: Wed, 6 Jul 2022 19:36:56 +0800 Subject: [PATCH 2/2] fix(inter): add 1e-12 as eps value in jnp.sqrt function --- dmff/classical/inter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dmff/classical/inter.py b/dmff/classical/inter.py index 66830c579..2ef172870 100644 --- a/dmff/classical/inter.py +++ b/dmff/classical/inter.py @@ -63,7 +63,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1) eps_m2 = eps_m1.T - eps_mat = jnp.sqrt(eps_m1 * eps_m2) + eps_mat = jnp.sqrt(eps_m1 * eps_m2 + 1e-12) sig_m1 = jnp.repeat(sigma.reshape((-1, 1)), sigma.shape[0], axis=1) sig_m2 = sig_m1.T sig_mat = (sig_m1 + sig_m2) * 0.5 @@ -85,7 +85,6 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales): eps_scale = eps * mscale_pair E_inter = get_LJ_energy(dr_vec, sig, eps_scale, box) - return jnp.sum(E_inter * mask) return get_energy