Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,16 @@ def render(self, filename):
def getParameters(self):
return self.paramtree

def updateParameters(self, paramtree):
def update_iter(node, ref):
for key in ref:
if isinstance(ref[key], dict):
update_iter(node[key], ref[key])
else:
node[key] = ref[key]

update_iter(self.paramtree, paramtree)


class HarmonicBondJaxGenerator:
def __init__(self, ff):
Expand Down Expand Up @@ -1996,6 +2006,9 @@ def extract(self):
prop_phase[f"{npred}"])
self.paramtree[self.name]["prop_k"][f"{npred}"] = jnp.array(
prop_k[f"{npred}"])
if self.max_pred_prop == 0:
del self.paramtree[self.name]["prop_phase"]
del self.paramtree[self.name]["prop_k"]

# impropers
impr_phase = defaultdict(list)
Expand All @@ -2021,14 +2034,17 @@ def extract(self):
impr_phase[f"{npred}"])
self.paramtree[self.name]["impr_k"][f"{npred}"] = jnp.array(
impr_k[f"{npred}"])
if self.max_pred_impr == 0:
del self.paramtree[self.name]["impr_phase"]
del self.paramtree[self.name]["impr_k"]

def overwrite(self):
propers = self.fftree.get_nodes("PeriodicTorsionForce/Proper")
impropers = self.fftree.get_nodes("PeriodicTorsionForce/Improper")
prop_data = [{} for _ in propers]
impr_data = [{} for _ in impropers]
# make propers
for periodicity in range(1, self.max_pred_prop):
for periodicity in range(1, self.max_pred_prop+1):
nterms = len(
self.paramtree[self.name][f"prop_phase"][f"{periodicity}"])
for nitem in range(nterms):
Expand All @@ -2040,10 +2056,11 @@ def overwrite(self):
order = self.meta[f"prop_order"][f"{periodicity}"][nitem]
prop_data[nodeidx][f"phase{order}"] = phase
prop_data[nodeidx][f"k{order}"] = k
self.fftree.set_node("PeriodicTorsionForce/Proper", prop_data)
if "prop_phase" in self.paramtree[self.name]:
self.fftree.set_node("PeriodicTorsionForce/Proper", prop_data)

# make impropers
for periodicity in range(1, self.max_pred_impr):
for periodicity in range(1, self.max_pred_impr+1):
nterms = len(
self.paramtree[self.name][f"impr_phase"][f"{periodicity}"])
for nitem in range(nterms):
Expand All @@ -2055,7 +2072,8 @@ def overwrite(self):
order = self.meta[f"impr_order"][f"{periodicity}"][nitem]
impr_data[nodeidx][f"phase{order}"] = phase
impr_data[nodeidx][f"k{order}"] = k
self.fftree.set_node("PeriodicTorsionForce/Improper", impr_data)
if "impr_phase" in self.paramtree[self.name]:
self.fftree.set_node("PeriodicTorsionForce/Improper", impr_data)

def createForce(self, sys, data, nonbondedMethod, nonbondedCutoff, args):
proper_matcher = TypeMatcher(self.fftree,
Expand Down
3 changes: 1 addition & 2 deletions dmff/classical/inter.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales):

eps_m1 = jnp.repeat(epsilon.reshape((-1, 1)), epsilon.shape[0], axis=1)
eps_m2 = eps_m1.T
eps_mat = jnp.sqrt(eps_m1 * eps_m2)
eps_mat = jnp.sqrt(eps_m1 * eps_m2 + 1e-12)
sig_m1 = jnp.repeat(sigma.reshape((-1, 1)), sigma.shape[0], axis=1)
sig_m2 = sig_m1.T
sig_mat = (sig_m1 + sig_m2) * 0.5
Expand All @@ -85,7 +85,6 @@ def get_energy(positions, box, pairs, epsilon, sigma, epsfix, sigfix, mscales):
eps_scale = eps * mscale_pair

E_inter = get_LJ_energy(dr_vec, sig, eps_scale, box)

return jnp.sum(E_inter * mask)

return get_energy
Expand Down