Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 177 additions & 97 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jax_md import space, partition
from jax import grad
import linecache
import sys


def get_line_context(file_path, line_number):
Expand Down Expand Up @@ -40,81 +41,6 @@ def build_covalent_map(data, max_neighbor):
return covalent_map


def set_axis_type(map_atomtypes, types, params):

ZThenX = 0
Bisector = 1
ZBisect = 2
ThreeFold = 3
Zonly = 4
NoAxisType = 5
LastAxisTypeIndex = 6
kStrings = ["kz", "kx", "ky"]
axisIndices = []
axisTypes = []

for i in map_atomtypes:
atomType = types[i]

kIndices = [atomType]

for kString in kStrings:
kString_value = params[kString][i]
if kString_value != "":
kIndices.append(kString_value)
axisIndices.append(kIndices)

# set axis type

kIndicesLen = len(kIndices)

if kIndicesLen > 3:
ky = kIndices[3]
kyNegative = False
if ky.startswith("-"):
ky = kIndices[3] = ky[1:]
kyNegative = True
else:
ky = ""

if kIndicesLen > 2:
kx = kIndices[2]
kxNegative = False
if kx.startswith("-"):
kx = kIndices[2] = kx[1:]
kxNegative = True
else:
kx = ""

if kIndicesLen > 1:
kz = kIndices[1]
kzNegative = False
if kz.startswith("-"):
kz = kIndices[1] = kz[1:]
kzNegative = True
else:
kz = ""

while len(kIndices) < 4:
kIndices.append("")

axisType = ZThenX
if not kz:
axisType = NoAxisType
if kz and not kx:
axisType = Zonly
if kz and kzNegative or kx and kxNegative:
axisType = Bisector
if kx and kxNegative and ky and kyNegative:
axisType = ZBisect
if kz and kzNegative and kx and kxNegative and ky and kyNegative:
axisType = ThreeFold

axisTypes.append(axisType)

return np.array(axisTypes), np.array(axisIndices)


class ADMPDispGenerator:
def __init__(self, hamiltonian):
self.ff = hamiltonian
Expand Down Expand Up @@ -148,6 +74,7 @@ def parseElement(element, hamiltonian):
mScales = []
for i in range(2, 7):
mScales.append(float(element.attrib["mScale1%d" % i]))
mScales.append(1.0)
generator.params["mScales"] = mScales
for atomtype in element.findall("Atom"):
generator.registerAtomType(atomtype.attrib)
Expand Down Expand Up @@ -279,7 +206,7 @@ def registerAtomType(self, atom: dict):
if kString in atom:
self.kStrings[kString].append(atom.pop(kString))
else:
self.kStrings[kString].append("")
self.kStrings[kString].append("0")

for k, v in atom.items():
self._input_params[k].append(float(v))
Expand All @@ -300,6 +227,11 @@ def parseElement(element, hamiltonian):
generator.params["dScales"].append(
float(element.attrib["dScale1%d" % i]))

# make sure the last digit is 1.0
generator.params['mScales'].append(1.0)
generator.params['pScales'].append(1.0)
generator.params['dScales'].append(1.0)

if element.findall('Polarize'):
generator.lpol = True
else:
Expand Down Expand Up @@ -368,7 +300,6 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
args):

n_atoms = len(data.atoms)
# build index map
map_atomtype = np.zeros(n_atoms, dtype=int)

for i in range(n_atoms):
Expand All @@ -386,28 +317,177 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff,
covalent_map = build_covalent_map(data, 6)

# build intra-molecule axis
# the following code is the direct transplant of forcefield.py in openmm 7.4.0

if self.lmax > 0:
self.axis_types, self.axis_indices = set_axis_type(
map_atomtype, self.types, self.kStrings)
map_axis_indices = []
# map axis_indices
for i in range(n_atoms):
catom = data.atoms[i]
residue = catom.residue._atoms
atom_indices = [
index if index != "" else -1
for index in self.axis_indices[i][1:]
]
for atom in residue:
if atom == catom:
continue
for i in range(len(atom_indices)):
if atom_indices[i] == data.atomType[atom]:
atom_indices[i] = atom.index
break
map_axis_indices.append(atom_indices)

self.axis_indices = np.array(map_axis_indices)
# setting up axis_indices and axis_type
ZThenX = 0
Bisector = 1
ZBisect = 2
ThreeFold = 3
Zonly = 4
NoAxisType = 5
LastAxisTypeIndex = 6

self.axis_types = []
self.axis_indices = []
for i_atom in range(n_atoms):
atom = data.atoms[i_atom]
t = data.atomType[atom]
# if t is in type list?
if t in self.types:
itypes = np.where(self.types == t)[0]
hit = 0
# try to assign multipole parameters via only 1-2 connected atoms
for itype in itypes:
if hit != 0:
break
kz = int(self.kStrings['kz'][itype])
kx = int(self.kStrings['kx'][itype])
ky = int(self.kStrings['ky'][itype])
neighbors = np.where(covalent_map[i_atom] == 1)[0]
zaxis = -1
xaxis = -1
yaxis = -1
for z_index in neighbors:
if hit != 0:
break
z_type = int(data.atomType[data.atoms[z_index]])
if z_type == abs(kz): # find the z atom, start searching for x
for x_index in neighbors:
if x_index == z_index or hit != 0:
continue
x_type = int(data.atomType[data.atoms[x_index]])
if x_type == abs(kx): # find the x atom, start searching for y
if ky == 0:
zaxis = z_index
xaxis = x_index
# cannot ditinguish x and z? use the smaller index for z, and the larger index for x
if (x_type == z_type and xaxis < zaxis):
swap = z_axis
z_axis = x_axis
x_axis = swap
# otherwise, try to see if we can find an even smaller index for x?
else:
for x_index in neighbors:
x_type1 = int(data.atomType[data.atoms[x_index]])
if x_type1 == abs(kx) and x_index != z_index and x_index < xaxis:
xaxis = x_index
hit = 1 # hit, finish matching
matched_itype = itype
else:
for y_index in neighbors:
if (y_index == z_index or y_index == x_index or hit != 0):
continue
y_type = int(data.atomType[data.atoms[y_index]])
if y_type == abs(ky):
zaxis = z_index
xaxis = x_index
yaxis = y_index
hit = 2
matched_itype = itype
# assign multipole parameters via 1-2 and 1-3 connected atoms
for itype in itypes:
if hit != 0:
break
kz = int(self.kStrings['kz'][itype])
kx = int(self.kStrings['kx'][itype])
ky = int(self.kStrings['ky'][itype])
neighbors_1st = np.where(covalent_map[i_atom] == 1)[0]
neighbors_2nd = np.where(covalent_map[i_atom] == 2)[0]
zaxis = -1
xaxis = -1
yaxis = -1
for z_index in neighbors_1st:
if hit != 0:
break
z_type = int(data.atomType[data.atoms[z_index]])
if z_type == abs(kz):
for x_index in neighbors_2nd:
if x_index == z_index or hit != 0:
continue
x_type = int(data.atomType[data.atoms[x_index]])
# we ask x to be in 2'nd neighbor, and x is z's neighbor
if x_type == abs(kx) and covalent_map[z_index, x_index] == 1:
if ky == 0:
zaxis = z_index
xaxis = x_index
# select smallest x index
for x_index in neighbors_2nd:
x_type1 = int(data.atomType[data.atoms[x_index]])
if x_type1 == abs(kx) and x_index != z_index and covalent_map[x_index, z_index] == 1 and x_index < xaxis:
xaxis = x_index
hit = 3
matched_itype = itype
else:
for y_index in neighbors_2nd:
if y_index == z_index or y_index == x_index or hit != 0:
continue
y_type = int(data.atomType[data.atoms[y_index]])
if y_type == abs(ky) and covalent_map[y_index, z_index] == 1:
zaxis = z_index
xaxis = x_index
yaxis = y_index
hit = 4
matched_itype = itype
# assign multipole parameters via only a z-defining atom
for itype in itypes:
if hit != 0:
break
kz = int(self.kStrings['kz'][itype])
kx = int(self.kStrings['kx'][itype])
zaxis = -1
xaxis = -1
yaxis = -1
neighbors = np.where(covalent_map[i_atom] == 1)[0]
for z_index in neighbors:
if hit != 0:
break
z_type = int(data.atomType[data.atoms[z_index]])
if kx == 0 and z_type == abs(kz):
zaxis = z_index
hit = 5
matched_itype = itype
# assign multipole parameters via no connected atoms
for itype in itypes:
if hit != 0:
break
kz = int(self.kStrings['kz'][itype])
zaxis = -1
xaxis = -1
yaxis = -1
if kz == 0:
hit = 6
matched_itype = itype
# add particle if there was a hit
if hit != 0:
map_atomtype[i_atom] = matched_itype
self.axis_indices.append([zaxis, xaxis, yaxis])

kz = int(self.kStrings['kz'][matched_itype])
kx = int(self.kStrings['kx'][matched_itype])
ky = int(self.kStrings['ky'][matched_itype])
axisType = ZThenX
if (kz == 0):
axisType = NoAxisType
if (kz != 0 and kx == 0):
axisType = ZOnly
if (kz < 0 or kx < 0):
axisType = Bisector
if (kx < 0 and ky < 0):
axisType = ZBisect
if (kz < 0 and kx < 0 and ky < 0):
axisType = ThreeFold
self.axis_types.append(axisType)

else:
sys.exit('Atom %d not matched in forcefield!'%i_atom)

else:
sys.exit('Atom %d not matched in forcefield!'%i_atom)
self.axis_indices = np.array(self.axis_indices)
self.axis_types = np.array(self.axis_types)
else:
self.axis_types = None
self.axis_indices = None
Expand Down