Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ut.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,5 @@ jobs:
pytest -vs tests/test_common/test_*
pytest -vs tests/test_admp/test_*
pytest -vs tests/test_utils.py
pytest -vs tests/test_mbar/test_*
pytest -vs tests/test_mbar/test_*
pytest -vs tests/test_sgnn/test_*
2 changes: 1 addition & 1 deletion dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
print("Node with different number of nodes.")
# print("Node with different number of nodes.")
return False, {}, {}

def match_func(n1, n2):
Expand Down
3 changes: 2 additions & 1 deletion dmff/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .classical import *
from .admp import *
from .admp import *
from .ml import *
34 changes: 24 additions & 10 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,15 +822,28 @@ def __init__(self, ffinfo: dict, paramset: ParamSet):
kzs.append(kz)
# record multipoles
c0.append(float(attribs["c0"]))
dX.append(float(attribs["dX"]))
dY.append(float(attribs["dY"]))
dZ.append(float(attribs["dZ"]))
qXX.append(float(attribs["qXX"]))
qYY.append(float(attribs["qYY"]))
qZZ.append(float(attribs["qZZ"]))
qXY.append(float(attribs["qXY"]))
qXZ.append(float(attribs["qXZ"]))
qYZ.append(float(attribs["qYZ"]))
if self.lmax >= 1:
dX.append(float(attribs["dX"]))
dY.append(float(attribs["dY"]))
dZ.append(float(attribs["dZ"]))
else:
dX.append(0.0)
dY.append(0.0)
dZ.append(0.0)
if self.lmax >= 2:
qXX.append(float(attribs["qXX"]))
qYY.append(float(attribs["qYY"]))
qZZ.append(float(attribs["qZZ"]))
qXY.append(float(attribs["qXY"]))
qXZ.append(float(attribs["qXZ"]))
qYZ.append(float(attribs["qYZ"]))
else:
qXX.append(0.0)
qYY.append(0.0)
qZZ.append(0.0)
qXY.append(0.0)
qXZ.append(0.0)
qYZ.append(0.0)
mask = 1.0
if "mask" in attribs and attribs["mask"].upper() == "TRUE":
mask = 0.0
Expand Down Expand Up @@ -1146,6 +1159,7 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutof
pme_force = ADMPPmeForce(box, axis_types, axis_indices, rc,
self.ethresh, self.lmax, self.lpol, lpme,
self.step_pol)
self.pme_force = pme_force

def potential_fn(positions, box, pairs, params):
positions = positions * 10
Expand Down Expand Up @@ -1181,4 +1195,4 @@ def getMetaData(self):
return self._meta


_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
_DMFFGenerators["ADMPPmeForce"] = ADMPPmeGenerator
74 changes: 74 additions & 0 deletions dmff/generators/ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from ..api.topology import DMFFTopology
from ..api.paramset import ParamSet
from ..api.hamiltonian import _DMFFGenerators
from ..utils import DMFFException, isinstance_jnp
from ..utils import jit_condition
import numpy as np
import jax
import jax.numpy as jnp
import openmm.app as app
import openmm.unit as unit
import pickle

from ..sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
from ..sgnn.gnn import MolGNNForce, prm_transform_f2i


class SGNNGenerator:
def __init__(self, ffinfo: dict, paramset: ParamSet):

self.name = "SGNNForce"
self.ffinfo = ffinfo
paramset.addField(self.name)
self.key_type = None

self.file = self.ffinfo["Forces"][self.name]["meta"]["file"]
self.nn = int(self.ffinfo["Forces"][self.name]["meta"]["nn"])
self.pdb = self.ffinfo["Forces"][self.name]["meta"]["pdb"]

# load ML potential parameters
with open(self.file, 'rb') as ifile:
params = pickle.load(ifile)

# convert to jnp array
for k in params:
params[k] = jnp.array(params[k])
# set mask to all true
paramset.addParameter(params[k], k, field=self.name, mask=jnp.ones(params[k].shape))

# mask = jax.tree_util.tree_map(lambda x: jnp.ones(x.shape), params)
# paramset.addParameter(params, "params", field=self.name, mask=mask)


def getName(self) -> str:
return self.name

def overwrite(self, paramset):
# do not use xml to handle ML potentials
# for ML potentials, xml only documents param file path
# so for ML potentials, overwrite function overwrites the file directly
with open(self.file, 'wb') as ofile:
pickle.dump(paramset[self.name], ofile)
return

def createPotential(self, topdata: DMFFTopology, nonbondedMethod, nonbondedCutoff, **kwargs):
self.G = from_pdb(self.pdb)
n_atoms = topdata.getNumAtoms()
self.model = MolGNNForce(self.G, nn=self.nn)
n_layers = self.model.n_layers
def potential_fn(positions, box, pairs, params):
# convert unit to angstrom
positions = positions * 10
box = box * 10
prms = prm_transform_f2i(params[self.name], n_layers)
return self.model.get_energy(positions, box, prms)

self._jaxPotential = potential_fn
return potential_fn

def getJaxPotential(self):
return self._jaxPotential


_DMFFGenerators["SGNNForce"] = SGNNGenerator

76 changes: 40 additions & 36 deletions dmff/sgnn/gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,39 @@
from jax import value_and_grad, vmap


def prm_transform_f2i(params, n_layers):
p = {}
for k in params:
p[k] = jnp.array(params[k])
for i_nn in [0, 1]:
nn_name = 'fc%d' % i_nn
p['%s.weight' % nn_name] = []
p['%s.bias' % nn_name] = []
for i_layer in range(n_layers[i_nn]):
k_w = '%s.%d.weight' % (nn_name, i_layer)
k_b = '%s.%d.bias' % (nn_name, i_layer)
p['%s.weight' % nn_name].append(p.pop(k_w, None))
p['%s.bias' % nn_name].append(p.pop(k_b, None))
return p


def prm_transform_i2f(params, n_layers):
# transform format
p = {}
p['w'] = params['w']
p['fc_final.weight'] = params['fc_final.weight']
p['fc_final.bias'] = params['fc_final.bias']
for i_nn in range(2):
nn_name = 'fc%d' % i_nn
for i_layer in range(n_layers[i_nn]):
p[nn_name + '.%d.weight' %
i_layer] = params[nn_name + '.weight'][i_layer]
p[nn_name +
'.%d.bias' % i_layer] = params[nn_name +
'.bias'][i_layer]
return p


class MolGNNForce:

def __init__(self,
Expand Down Expand Up @@ -146,6 +179,7 @@ def message_pass(f_in, nb_connect, w, nn):

return


def load_params(self, ifn):
""" Load the network parameters from saved file

Expand All @@ -160,32 +194,12 @@ def load_params(self, ifn):
for k in params.keys():
params[k] = jnp.array(params[k])
# transform format
keys = list(params.keys())
for i_nn in [0, 1]:
nn_name = 'fc%d' % i_nn
keys_weight = []
keys_bias = []
for k in keys:
if re.search(nn_name + '.[0-9]+.weight', k) is not None:
keys_weight.append(k)
elif re.search(nn_name + '.[0-9]+.bias', k) is not None:
keys_bias.append(k)
if len(keys_weight) != self.n_layers[i_nn] or len(
keys_bias) != self.n_layers[i_nn]:
sys.exit(
'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!'
)
params['%s.weight' % nn_name] = []
params['%s.bias' % nn_name] = []
for i_layer in range(self.n_layers[i_nn]):
k_w = '%s.%d.weight' % (nn_name, i_layer)
k_b = '%s.%d.bias' % (nn_name, i_layer)
params['%s.weight' % nn_name].append(params.pop(k_w, None))
params['%s.bias' % nn_name].append(params.pop(k_b, None))
# params[nn_name]
self.params = params
self.params = prm_transform_f2i(params, self.n_layers)
return




def save_params(self, ofn):
""" Save the network parameters to a pickle file

Expand All @@ -196,18 +210,8 @@ def save_params(self, ofn):

"""
# transform format
params = {}
params['w'] = self.params['w']
params['fc_final.weight'] = self.params['fc_final.weight']
params['fc_final.bias'] = self.params['fc_final.bias']
for i_nn in range(2):
nn_name = 'fc%d' % i_nn
for i_layer in range(self.n_layers[i_nn]):
params[nn_name + '.%d.weight' %
i_layer] = self.params[nn_name + '.weight'][i_layer]
params[nn_name +
'.%d.bias' % i_layer] = self.params[nn_name +
'.bias'][i_layer]
params = prm_transform_i2f(self.params, self.n_layers)
with open(ofn, 'wb') as ofile:
pickle.dump(params, ofile)
return

14 changes: 14 additions & 0 deletions dmff/sgnn/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,20 @@ def from_pdb(pdb):
return TopGraph(list_atom_elems, bonds, positions=positions, box=box)


# def from_dmff_top(topdata):
# '''
# Build the sGNN TopGraph object from a DMFFTopology object

# Parameters
# ----------
# topdata: DMFFTopology data
# '''
# list_atom_elems = np.array([a.element for a in topdata.atoms()])
# bonds = np.array([np.sort([b.atom1.index, b.atom2.index]) for b in topdata.bonds()])
# n_atoms = len(list_atom_elems)
# return TopGraph(list_atom_elems, bonds, positions=jnp.zeros((n_atoms, 3)), box=jnp.eye(3)*10)


def validation():
G = from_pdb('peg4.pdb')
nn = 1
Expand Down
3 changes: 2 additions & 1 deletion examples/classical/test_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,5 @@ def getEnergyDecomposition(context, forcegroups):
print("Nonbonded:", nbE(positions, box, pairs, params))

etotal = pot.getPotentialFunc()
print("Total:", etotal(positions, box, pairs, params))
print("Total:", etotal(positions, box, pairs, params))

Binary file removed examples/sgnn/model1.pickle
Binary file not shown.
1 change: 1 addition & 0 deletions examples/sgnn/model1.pickle
48 changes: 48 additions & 0 deletions examples/sgnn/peg.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<ForceField>
<AtomTypes>
<Type class="CT" element="C" mass="12.0107" name="1" />
<Type class="HC" element="H" mass="1.00784" name="2" />
<Type class="OS" element="O" mass="15.999" name="3" />
<Type class="CT" element="C" mass="12.0107" name="4" />
<Type class="HC" element="H" mass="1.00784" name="5" />
</AtomTypes>
<Residues>
<Residue name="TER">
<Atom name="C00" type="1" />
<Atom name="H01" type="2" />
<Atom name="H02" type="2" />
<Atom name="O03" type="3" />
<Atom name="C04" type="4" />
<Atom name="H05" type="5" />
<Atom name="H06" type="5" />
<Atom name="H07" type="5" />
<Bond from="0" to="1" />
<Bond from="0" to="2" />
<Bond from="0" to="3" />
<Bond from="3" to="4" />
<Bond from="4" to="5" />
<Bond from="4" to="6" />
<Bond from="4" to="7" />
<ExternalBond atomName="C00" />
</Residue>
<Residue name="INT">
<Atom name="C00" type="1" />
<Atom name="H01" type="2" />
<Atom name="H02" type="2" />
<Atom name="O03" type="3" />
<Atom name="C04" type="1" />
<Atom name="H05" type="2" />
<Atom name="H06" type="2" />
<Bond from="0" to="1" />
<Bond from="0" to="2" />
<Bond from="0" to="3" />
<Bond from="3" to="4" />
<Bond from="4" to="5" />
<Bond from="4" to="6" />
<ExternalBond atomName="C00" />
<ExternalBond atomName="C04" />
</Residue>
</Residues>
<SGNNForce file="model1.pickle" pdb="peg4.pdb" nn="1"/>
</ForceField>

Loading