Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 6 additions & 45 deletions dmff/admp/mbpol_intra.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import sys
import numpy as np
import jax.numpy as jnp
Expand All @@ -10,8 +9,6 @@
from dmff.admp.parser import *
from jax import vmap
import time
#from admp.multipole import convert_cart2harm
#from jax_md import partition, space

#const
f5z = 0.999677885
Expand Down Expand Up @@ -467,12 +464,14 @@ def onebodyenergy(positions, box):

@vmap
@jit_condition(static_argnums={})
def onebody_kernel(x1, x2, x3, Va, Vb, efac):
def onebody_kernel(x1, x2, x3, Va, Vb, efac):
a = jnp.arange(-1,15)
a = a.at[0].set(0)
const = jnp.array([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
CONST = jnp.array([const,const,const])
list1 = jnp.array([x1**i for i in range(-1, 15)])
list2 = jnp.array([x2**i for i in range(-1, 15)])
list3 = jnp.array([x3**i for i in range(-1, 15)])
list1 = jnp.array([x1**i for i in a])
list2 = jnp.array([x2**i for i in a])
list3 = jnp.array([x3**i for i in a])
fmat = jnp.array([list1, list2, list3])
fmat *= CONST
F1 = jnp.sum(fmat[0].T * matrix1, axis=1) # fmat[0][inI] 1*245
Expand All @@ -489,41 +488,3 @@ def onebody_kernel(x1, x2, x3, Va, Vb, efac):
e1 *= cal2joule # conver cal 2 j
return e1


def validation(pdb):
    """Validate the one-body intra-molecular energy on a PDB structure.

    Reads atomic positions and the box from *pdb*, evaluates the
    one-body energy via ``onebodyenergy`` together with its gradient
    with respect to positions, and prints both.

    Parameters
    ----------
    pdb: str
        Path to the input PDB file.
    """
    pdbinfo = read_pdb(pdb)
    positions = jnp.asarray(pdbinfo['positions'])
    # the box comes back as cell parameters (a, b, c, alpha, beta, gamma);
    # only the lengths are used, i.e. an orthorhombic cell is assumed
    lx, ly, lz, _, _, _ = pdbinfo['box']
    box = jnp.eye(3) * jnp.array([lx, ly, lz])

    # energy and d(energy)/d(positions) in a single pass
    grad_E1 = value_and_grad(onebodyenergy, argnums=(0))
    ene, force = grad_E1(positions, box)
    print(ene, force)
    return


# below is the validation code
if __name__ == '__main__':
validation(sys.argv[1])


172 changes: 104 additions & 68 deletions dmff/sgnn/gnn.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,54 @@
#!/usr/bin/env python
import sys
import numpy as np
import jax.numpy as jnp
import jax.lax as lax
from jax import vmap, value_and_grad
import jax.nn.initializers
from dmff.utils import jit_condition
from dmff.sgnn.graph import MAX_VALENCE
from dmff.sgnn.graph import TopGraph, from_pdb
import pickle
import re
import sys
from collections import OrderedDict
from functools import partial

class MolGNN:

def __init__(self, G, n_layers=(3, 2), sizes=[(40, 20, 20), (20, 10)], nn=1,
sigma=162.13039087945623, mu=117.41975505778706, seed=12345):
import jax.lax as lax
import jax.nn.initializers
import jax.numpy as jnp
import numpy as np
from dmff.sgnn.graph import MAX_VALENCE, TopGraph, from_pdb
from dmff.utils import jit_condition
from jax import value_and_grad, vmap


class MolGNNForce:

def __init__(self,
G,
n_layers=(3, 2),
sizes=[(40, 20, 20), (20, 10)],
nn=1,
sigma=162.13039087945623,
mu=117.41975505778706,
seed=12345):
""" Constructor for MolGNNForce

Parameters
----------
G: TopGraph object
The topological graph object, created using dmff.sgnn.graph.TopGraph
n_layers: int tuple, optional
Number of hidden layers before and after message passing
default = (3, 2)
sizes: [tuple, tuple], optional
sizes (numbers of hidden neurons) of the network before and after message passing
default = [(40, 20, 20), (20, 10)]
nn: int, optional
size of the subgraphs, i.e., how many neighbors to include around the central bond
default = 1
sigma: float, optional
final scaling factor of the energy.
default = 162.13039087945623
mu: float, optional
a constant shift
the final total energy would be ${(E_{NN} + \mu) * \sigma}
seed: int: optional
random seed used in network initialization
default = 12345

"""
self.nn = nn
self.G = G
self.G.get_all_subgraphs(nn, typify=True)
Expand All @@ -29,17 +61,18 @@ def __init__(self, G, n_layers=(3, 2), sizes=[(40, 20, 20), (20, 10)], nn=1,
dim_in = G.n_features
initializer = jax.nn.initializers.he_uniform()
for i_nn, n_layers in enumerate(n_layers):
nn_name = 'fc%d'%i_nn
nn_name = 'fc%d' % i_nn
params[nn_name + '.weight'] = []
params[nn_name + '.bias'] = []
for i_layer in range(n_layers):
layer_name = nn_name + '.' + '%d'%i_layer
layer_name = nn_name + '.' + '%d' % i_layer
dim_out = sizes[i_nn][i_layer]
# params[nn_name+'.weight'].append(jnp.array(np.random.random((dim_out, dim_in))))
# params[nn_name+'.bias'].append(jnp.array(np.random.random(dim_out)))
key, subkey = jax.random.split(key)
params[nn_name+'.weight'].append(initializer(subkey, (dim_out, dim_in)))
params[nn_name+'.bias'].append(jnp.zeros(dim_out))
params[nn_name + '.weight'].append(
initializer(subkey, (dim_out, dim_in)))
params[nn_name + '.bias'].append(jnp.zeros(dim_out))
dim_in = dim_out
key, subkey = jax.random.split(key)
params['fc_final.weight'] = jnp.array(initializer(subkey, (1, dim_in)))
Expand All @@ -60,30 +93,34 @@ def forward(positions, box, params, nn):
def fc0(f_in, params):
f = f_in
for i in range(self.n_layers[0]):
f = jnp.tanh(params['fc0.weight'][i].dot(f) + params['fc0.bias'][i])
f = jnp.tanh(params['fc0.weight'][i].dot(f) +
params['fc0.bias'][i])
return f

@jit_condition(static_argnums=())
@partial(vmap, in_axes=(0, None), out_axes=(0))
def fc1(f_in, params):
f = f_in
for i in range(self.n_layers[1]):
f = jnp.tanh(params['fc1.weight'][i].dot(f) + params['fc1.bias'][i])
f = jnp.tanh(params['fc1.weight'][i].dot(f) +
params['fc1.bias'][i])
return f

@jit_condition(static_argnums=())
@partial(vmap, in_axes=(0, None), out_axes=(0))
def fc_final(f_in, params):
return params['fc_final.weight'].dot(f_in) + params['fc_final.bias']
return params['fc_final.weight'].dot(
f_in) + params['fc_final.bias']

# @jit_condition(static_argnums=(3))
@partial(vmap, in_axes=(0, 0, None, None), out_axes=(0))
def message_pass(f_in, nb_connect, w, nn):
if nn == 0:
return f_in[0]
elif nn == 1:
nb_connect0 = nb_connect[0:MAX_VALENCE-1]
nb_connect1 = nb_connect[MAX_VALENCE-1:2*(MAX_VALENCE-1)]
nb_connect0 = nb_connect[0:MAX_VALENCE - 1]
nb_connect1 = nb_connect[MAX_VALENCE - 1:2 *
(MAX_VALENCE - 1)]
nb0 = jnp.sum(nb_connect0)
nb1 = jnp.sum(nb_connect1)
f = f_in[0] * (1 - jnp.heaviside(nb0, 0)*w - jnp.heaviside(nb1, 0)*w) + \
Expand All @@ -92,86 +129,85 @@ def message_pass(f_in, nb_connect, w, nn):
return f

features = fc0(features, params)
features = message_pass(features, self.G.nb_connect, params['w'], self.G.nn)
features = message_pass(features, self.G.nb_connect, params['w'],
self.G.nn)
features = fc1(features, params)
energies = fc_final(features, params)

return self.G.weights.dot(energies)[0] * self.sigma + self.mu

self.forward = partial(forward, nn=self.G.nn)
self.batch_forward = vmap(self.forward, in_axes=(0, 0, None), out_axes=(0))
self.batch_forward = vmap(self.forward,
in_axes=(0, 0, None),
out_axes=(0))

return
# provide the get_energy function, to be consistent with the other parts of DMFF
self.get_energy = self.forward

return

def load_params(self, ifn):
""" Load the network parameters from saved file

Parameters
----------
ifn: string
the input file name

"""
with open(ifn, 'rb') as ifile:
params = pickle.load(ifile)
for k in params.keys():
params[k] = jnp.array(params[k])
# transform format
keys = list(params.keys())
for i_nn in [0, 1]:
nn_name = 'fc%d'%i_nn
nn_name = 'fc%d' % i_nn
keys_weight = []
keys_bias = []
for k in keys:
if re.search(nn_name + '.[0-9]+.weight', k) is not None:
keys_weight.append(k)
elif re.search(nn_name + '.[0-9]+.bias', k) is not None:
keys_bias.append(k)
if len(keys_weight) != self.n_layers[i_nn] or len(keys_bias) != self.n_layers[i_nn]:
sys.exit('Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!')
params['%s.weight'%nn_name] = []
params['%s.bias'%nn_name] = []
if len(keys_weight) != self.n_layers[i_nn] or len(
keys_bias) != self.n_layers[i_nn]:
sys.exit(
'Error while loading GNN params, inconsistent inputs with the GNN structure, check your input!'
)
params['%s.weight' % nn_name] = []
params['%s.bias' % nn_name] = []
for i_layer in range(self.n_layers[i_nn]):
k_w = '%s.%d.weight'%(nn_name, i_layer)
k_b = '%s.%d.bias'%(nn_name, i_layer)
params['%s.weight'%nn_name].append(params.pop(k_w, None))
params['%s.bias'%nn_name].append(params.pop(k_b, None))
k_w = '%s.%d.weight' % (nn_name, i_layer)
k_b = '%s.%d.bias' % (nn_name, i_layer)
params['%s.weight' % nn_name].append(params.pop(k_w, None))
params['%s.bias' % nn_name].append(params.pop(k_b, None))
# params[nn_name]
self.params = params
return

return

def save_params(self, ofn):
""" Save the network parameters to a pickle file

Parameters
----------
ofn: string
the output file name

"""
# transform format
params = {}
params['w'] = self.params['w']
params['fc_final.weight'] = self.params['fc_final.weight']
params['fc_final.bias'] = self.params['fc_final.bias']
for i_nn in range(2):
nn_name = 'fc%d'%i_nn
nn_name = 'fc%d' % i_nn
for i_layer in range(self.n_layers[i_nn]):
params[nn_name+'.%d.weight'%i_layer] = self.params[nn_name+'.weight'][i_layer]
params[nn_name+'.%d.bias'%i_layer] = self.params[nn_name+'.bias'][i_layer]
params[nn_name + '.%d.weight' %
i_layer] = self.params[nn_name + '.weight'][i_layer]
params[nn_name +
'.%d.bias' % i_layer] = self.params[nn_name +
'.bias'][i_layer]
with open(ofn, 'wb') as ofile:
pickle.dump(params, ofile)
return


def validation():
    """Smoke-test the GNN force model: load parameters, print energy/forces.

    Builds the topology graph from a sample PDB, loads pretrained
    parameters, then evaluates the energy and the (negative-gradient)
    forces on one frame of the benchmark data set and prints them.
    """
    G = from_pdb('benchmark/peg4.pdb')
    # NOTE: the class defined in this module is MolGNNForce; the old
    # name `MolGNN` would raise a NameError here.
    model = MolGNNForce(G, nn=1)
    model.load_params('benchmark/model1.pickle')
    # warm-up evaluation on the graph's own geometry
    E = model.forward(G.positions, G.box, model.params)

    with open('benchmark/set009_remove_nb2.pickle', 'rb') as ifile:
        data = pickle.load(ifile)

    pos = jnp.array(data['positions'][0])
    box = jnp.eye(3) * 50

    # forces are the negative gradient of the energy w.r.t. positions
    E, F = value_and_grad(model.forward, argnums=(0))(pos, box, model.params)
    F = -F
    print(E)
    print(F)


if __name__ == '__main__':
    validation()

Loading