Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dmff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import dmff.settings
from dmff.common.nblist import NeighborList
from .settings import *
from .common.nblist import NeighborList
from .api import Hamiltonian
185 changes: 173 additions & 12 deletions dmff/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from dmff.utils import isinstance_jnp
from .admp.disp_pme import ADMPDispPmeForce
from .admp.multipole import convert_cart2harm
from .admp.multipole import convert_cart2harm, convert_harm2cart
from .admp.pairwise import TT_damping_qq_c6_kernel, generate_pairwise_interaction
from .admp.pairwise import slater_disp_damping_kernel, slater_sr_kernel, TT_damping_qq_kernel
from .admp.pme import ADMPPmeForce
Expand All @@ -30,9 +30,26 @@
CoulReactionFieldForce,
)
import sys
from copy import deepcopy


class XMLNodeInfo:

@staticmethod
def to_str(value)->str:
""" convert value to string if it can
"""
if isinstance(value, str):
return value
elif isinstance(value, (jnp.ndarray, np.ndarray)):
if value.ndim == 0:
return str(value)
else:
return str(value[0])
elif isinstance(value, list):
return value[0] # strip [] of value
else:
return str(value)

class XMLElementInfo:

Expand All @@ -41,29 +58,49 @@ def __init__(self, name):
self.attributes = {}

def addAttribute(self, key, value):
self.attributes[key] = value
self.attributes[key] = XMLNodeInfo.to_str(value)

def __repr__(self):
return f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}>'

def __getitem__(self, name):
return self.attributes[name]


def __init__(self, name):
self.name = name
self.attributes = {}
self.elements = []

def __getitem__(self, name):
if isinstance(name, str):
return self.attributes[name]
elif isinstance(name, int):
return self.elements[name]


def addAttribute(self, key, value):
self.attributes[key] = value
self.attributes[key] = XMLNodeInfo.to_str(value)


def addElement(self, name, info):
element = self.XMLElementInfo(name)
for k, v in info.items():
element.addAttribute(k, v)
self.elements.append(element)
self.elements.append(element)


def modResidue(self, residue, atom, key, value):
pass

def __repr__(self):
# tricy string formatting
left = f'<{self.name} {" ".join([f"{k}={v}" for k, v in self.attributes.items()])}> \n\t'
right = f'<\\{self.name}>'
content = '\n\t'.join([repr(e) for e in self.elements])
return left + content + '\n' + right



def get_line_context(file_path, line_number):
return linecache.getline(file_path, line_number).strip()
Expand Down Expand Up @@ -202,7 +239,18 @@ def getJaxPotential(self):

def renderXML(self):
# generate xml force field file
pass
finfo = XMLNodeInfo('ADMPDispForce')
finfo.addAttribute('mScale12', self.params["mScales"][0])
finfo.addAttribute('mScale13', self.params["mScales"][1])
finfo.addAttribute('mScale14', self.params["mScales"][2])
finfo.addAttribute('mScale15', self.params["mScales"][3])
finfo.addAttribute('mScale16', self.params["mScales"][4])

for i in range(len(self.types)):
ainfo = {'type': self.types[i], 'A': self.params["A"][i], 'B': self.params["B"][i], 'Q': self.params["Q"][i], 'C6': self.params["C6"][i], 'C8': self.params["C8"][i], 'C10': self.params["C10"][i]}
finfo.addElement('Atom', ainfo)

return finfo

# register all parsers
app.forcefield.parsers["ADMPDispForce"] = ADMPDispGenerator.parseElement
Expand Down Expand Up @@ -700,6 +748,7 @@ def parseElement(element, hamiltonian):
generator.types = np.array(generator.types)

n_atoms = len(element.findall("Atom"))
generator.n_atoms = n_atoms

# map atom multipole moments
if generator.lmax == 0:
Expand Down Expand Up @@ -1041,7 +1090,41 @@ def getJaxPotential(self):
return self._jaxPotential

def renderXML(self):
pass
# <ADMPPmeForce>

finfo = XMLNodeInfo('ADMPPmeForce')
finfo.addAttribute('lmax', str(self.lmax))
outputparams = deepcopy(self.params)
mScales = outputparams.pop('mScales')
pScales = outputparams.pop('pScales')
dScales = outputparams.pop('dScales')
for i in range(len(mScales)):
finfo.addAttribute(f'mScale1{i+2}', str(mScales[i]))
for i in range(len(pScales)):
finfo.addAttribute(f'pScale{i+1}', str(pScales[i]))
for i in range(len(dScales)):
finfo.addAttribute(f'dScale{i+1}', str(dScales[i]))

Q = outputparams['Q_local']
Q_global = convert_harm2cart(Q, self.lmax)

# <Atom>
for atom in range(self.n_atoms):
info = {'type': self.map_atomtype[atom]}
info.update({ktype:self.kStrings[ktype][atom] for ktype in ['kz', 'kx', 'ky']})
for i, key in enumerate(['c0', 'dX', 'dY', 'dZ', 'qXX', 'qXY', 'qXZ', 'qYY', 'qYZ', 'qZZ']):
info[key] = "%.8f" % Q_global[atom][i]
finfo.addElement('Atom', info)

# <Polarize>
for t in range(len(self.types)):
info = {
'type': self.types[t]
}
info.update({p: "%.8f" % self.params['pol'][t] for p in ['polarizabilityXX', 'polarizabilityYY', 'polarizabilityZZ']})
finfo.addElement('Polarize', info)

return finfo


app.forcefield.parsers["ADMPPmeForce"] = ADMPPmeGenerator.parseElement
Expand Down Expand Up @@ -1169,8 +1252,8 @@ def parseElement(element, hamiltonian):
"""
generator = HarmonicAngleJaxGenerator(hamiltonian)
hamiltonian.registerGenerator(generator)
for bondtype in element.findall("Angle"):
generator.registerAngleType(bondtype.attrib)
for angletype in element.findall("Angle"):
generator.registerAngleType(angletype.attrib)

def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):

Expand Down Expand Up @@ -1224,13 +1307,20 @@ def getJaxPotential(self):

def renderXML(self):
# generate xml force field file
pass
finfo = XMLNodeInfo("HarmonicAngleForce")
for i, type in enumerate(self.types):
t1, t2, t3 = type
ainfo = {'type1': t1, 'type2': t2, 'type3': t3, 'k': self.params['k'][i], 'angle': self.params['angle'][i]}
finfo.addElement('Angle', ainfo)

return finfo


# register all parsers
app.forcefield.parsers["HarmonicAngleForce"] = HarmonicAngleJaxGenerator.parseElement



def _matchImproper(data, torsion, generator):
type1 = data.atomType[data.atoms[torsion[0]]]
type2 = data.atomType[data.atoms[torsion[1]]]
Expand Down Expand Up @@ -1406,6 +1496,8 @@ def __init__(self, hamiltonian):
self.proper = []
self.improper = []
self.propersForAtomType = defaultdict(set)
self.n_proper = 0
self.n_improper = 0

def registerProperTorsion(self, parameters):
torsion = _parseTorsion(self.ff, parameters)
Expand Down Expand Up @@ -1437,7 +1529,7 @@ def parseElement(element, ff):

<PeriodicTorsionForce ordering="amber">
<Proper type1="" type2="c" type3="c" type4="" periodicity1="2" phase1="3.141592653589793" k1="1.2552"/>
<Proper type1="" type2="c" type3="c1" type4="" periodicity1="2" phase1="3.141592653589793" k1="0.0"/>
<Improper type1="" type2="c" type3="c1" type4="" periodicity1="2" phase1="3.141592653589793" k1="0.0"/>
</PeriodicTorsionForce>

"""
Expand Down Expand Up @@ -1773,8 +1865,53 @@ def getJaxPotential(self):
return self._jaxPotential

def renderXML(self):
params = self.params
# generate xml force field file
pass
finfo = XMLNodeInfo('PeriodicTorsionForce')
for i in range(len(self.proper)):
proper = self.proper[i]

finfo.addElement('Proper',
{'type1': proper.types1, 'type2': proper.types2,
'type3': proper.types3, 'type4': proper.types4,
'periodicity1': proper.periodicity[0],
'phase1': params['psi1_p'][i],
'k1': params['k1_p'][i],
'periodicity2': proper.periodicity[1],
'phase2': params['psi2_p'][i],
'k2': params['k2_p'][i],
'periodicity3': proper.periodicity[2],
'phase3': params['psi3_p'][i],
'k3': params['k3_p'][i],
'periodicity4': proper.periodicity[3],
'phase4': params['psi4_p'][i],
'k4': params['k4_p'][i],
}
)

for i in range(len(self.improper)):

improper = self.improper[i]

finfo.addElement('Improper',
{'type1': improper.types1, 'type2': improper.types2,
'type3': improper.types3, 'type4': improper.types4,
'periodicity1': improper.periodicity[0],
'phase1': params['psi1_i'][i],
'k1': params['k1_i'][i],
'periodicity2': proper.periodicity[1],
'phase2': params['psi2_i'][i],
'k2': params['k2_i'][i],
'periodicity3': proper.periodicity[2],
'phase3': params['psi3_i'][i],
'k3': params['k3_i'][i],
'periodicity4': proper.periodicity[3],
'phase4': params['psi4_i'][i],
'k4': params['k4_i'][i],
}
)

return finfo


app.forcefield.parsers[
Expand Down Expand Up @@ -1862,6 +1999,13 @@ def parseElement(element, ff):
generator.useAttributeFromResidue.append(eprm)
for atom in element.findall("Atom"):
generator.registerAtom(atom.attrib)

generator.n_atoms = len(element.findall("Atom"))

# jax it!
for k in generator.params.keys():
generator.params[k] = jnp.array(generator.params[k])
generator.types = np.array(generator.types)

def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):

Expand Down Expand Up @@ -1942,6 +2086,7 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
map_nbfix = []
# implement it later
map_nbfix = np.array(map_nbfix, dtype=int).reshape((-1, 2))


colv_map = build_covalent_map(data, 6)

Expand All @@ -1960,6 +2105,11 @@ def createForce(self, system, data, nonbondedMethod, nonbondedCutoff, args):
else:
r_switch = r_cut
ifSwitch = False

map_lj = jnp.array(map_lj)
map_nbfix = jnp.array(map_nbfix)
map_charge = jnp.array(map_charge)

ljforce = LennardJonesForce(
r_switch,
r_cut,
Expand Down Expand Up @@ -2012,12 +2162,23 @@ def getJaxPotential(self):
return self._jaxPotential

def renderXML(self):
pass

# <NonbondedForce>
finfo = XMLNodeInfo('NonbondedForce')
finfo.addAttribute('coulomb14scale', str(self.coulomb14scale))
finfo.addAttribute('lj14scale', str(self.lj14scale))

for atom in range(self.n_atoms):
info = {'type': self.types[atom], 'charge': self.params['charge'][atom], 'sigma': self.params['sigma'][atom], 'epsilon': self.params['epsilon'][atom]}
finfo.addElement('Atom', info)

return finfo


app.forcefield.parsers["NonbondedForce"] = NonbondJaxGenerator.parseElement



class Hamiltonian(app.forcefield.ForceField):
def __init__(self, *xmlnames):
super().__init__(*xmlnames)
Expand Down
1 change: 1 addition & 0 deletions dmff/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
if PRECISION == 'double':
config.update("jax_enable_x64", True)

__all__ = ['PRECISION', 'DO_JIT']
35 changes: 34 additions & 1 deletion docs/dev_guide/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ class SimpleJAXGenerator:
return self._jaxPotential

def renderXML(self):
render_xml_forcefield_from_params
# render_xml_forcefield_from_params


app.parsers["SimpleJAXForce"] = SimpleJAXGenerator.parseElement
Expand Down Expand Up @@ -291,6 +291,39 @@ class HarmonicBondJaxGenerator:
app.forcefield.parsers["HarmonicBondForce"] = HarmonicBondJaxGenerator.parseElement
```

After the calculation and optimization, we need to save the optimized parameters as XML format files for the next calculation. This serialization process is implemented through the `renderXML` method. At the beginning of the `api.py` file, we provide nested helper classes called `XMLNodeInfo` and `XMLElementInfo`. In the XML file, a `<HarmonicJaxBondForce>` and its close tag is represented by XMLNodeInfo and the content element is controlled by `XMLElementInfo`

```
<HarmonicJaxBondForce>
<Bond type1="ow" type2="hw" length="0.0957" k="462750.4"/>
<Bond type1="hw" type2="hw" length="0.1513" k="462750.4"/>
</HarmonicJaxBondForce>
```

When we want to serialize optimized parameters from the generator to a new XML file, we first initialize a `XMLNodeInfo(name:str)` class with the potential name

```python
finfo = XMLNodeInfo("HarmonicBondForce")
```
If necessary, you can add attributes to this tag using the `addAttribute(name:str, value:str)` method. Then we add the inner `<Bond>` tag by invoke `finfo.addElement(name:str, attrib:dict)` method. Here is an example to render `<HarmonicBondForce>`

```
def renderXML(self):
# generate xml force field file
finfo = XMLNodeInfo("HarmonicBondForce") # <HarmonicBondForce> and <\HarmonicBondForce>
for ntype in range(len(self.types)):
binfo = {}
k1, v1 = self.typetexts[ntype][0]
k2, v2 = self.typetexts[ntype][1]
binfo[k1] = v1
binfo[k2] = v2
for key in self.params.keys():
binfo[key] = "%.8f"%self.params[key][ntype]
finfo.addElement("Bond", binfo) # <Bond binfo.key=binfo.value ...>
return finfo
```


## How Backend Works

### Force Class
Expand Down
Loading