Skip to content
3 changes: 3 additions & 0 deletions dmff/admp/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def distribute_multipoles(multipoles, index):
def distribute_dispcoeff(c_list, index):
return c_list[index]

@jit_condition(static_argnums=())
def distribute_matrix(multipoles,index1,index2):
return multipoles[index1,index2]

def generate_pairwise_interaction(pair_int_kernel, covalent_map, static_args):
'''
Expand Down
7 changes: 5 additions & 2 deletions dmff/admp/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np
import warnings
from collections import defaultdict
import jax.numpy as jnp
from dmff.admp.multipole import convert_cart2harm

def read_atom_line(line_full):
"""
Expand Down Expand Up @@ -326,7 +328,8 @@ def read_xml(fileobj):
set_axis_type(atomTemplates)

return atomTemplates, residueTemplates



class Atom:

def __init__(self, serial, name, resName, resSeq, position, charge, ) -> None:
Expand Down Expand Up @@ -474,4 +477,4 @@ def assemble_covalent(residueDicts, natoms):
covalent_map[c][pp] = dr

return covalent_map

8 changes: 5 additions & 3 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from dmff.admp.pairwise import (
distribute_scalar,
distribute_v3,
distribute_multipoles
distribute_multipoles,
distribute_matrix
)


Expand Down Expand Up @@ -792,7 +793,8 @@ def pme_real(positions, box, pairs,
r2 = distribute_v3(positions, pairs[:, 1])
Q_extendi = distribute_multipoles(Q_global, pairs[:, 0])
Q_extendj = distribute_multipoles(Q_global, pairs[:, 1])
nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
nbonds = distribute_matrix(covalent_map,pairs[:, 0],pairs[:, 1])
#nbonds = covalent_map[pairs[:, 0], pairs[:, 1]]
indices = nbonds-1
mscales = distribute_scalar(mScales, indices)
mscales = mscales * buffer_scales
Expand Down Expand Up @@ -896,4 +898,4 @@ def pol_penalty(U_ind, pol):
# this is to remove the singularity when pol=0
pol_pi = trim_val_0(pol)
# pol_pi = pol/(jnp.exp((-pol+1e-08)*1e10)+1) + 1e-08/(jnp.exp((pol-1e-08)*1e10)+1)
return jnp.sum(0.5/pol_pi*(U_ind**2).T) * DIELECTRIC
return jnp.sum(0.5/pol_pi*(U_ind**2).T) * DIELECTRIC
1 change: 0 additions & 1 deletion dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
self.map_atomtype = map_atomtype
# build covalent map
covalent_map = build_covalent_map(data, 6)

# here box is only used to setup ewald parameters, no need to be differentiable
a, b, c = system.getDefaultPeriodicBoxVectors()
box = jnp.array([a._value, b._value, c._value]) * 10
Expand Down
Loading