Commit 2ed422f

Merge pull request #581 from amcadmus/devel

Add type embedding

amcadmus committed May 1, 2021
2 parents 0d0cb77 + 73482ac
Showing 17 changed files with 1,368 additions and 132 deletions.
208 changes: 147 additions & 61 deletions deepmd/descriptor/se_a.py
@@ -11,7 +11,7 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net
from deepmd.utils.tabulate import DeepTabulate
-
+from deepmd.utils.type_embed import embed_atom_type

class DescrptSeA ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
@@ -101,6 +101,11 @@ def __init__ (self,
self.davg = None
self.compress = False
self.place_holders = {}
+        nei_type = np.array([])
+        for ii in range(self.ntypes):
+            nei_type = np.append(nei_type, ii * np.ones(self.sel_a[ii]))  # like a mask
+        self.nei_type = tf.constant(nei_type, dtype = tf.int32)
+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
sub_graph = tf.Graph()
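
The `nei_type` constant built above records, for every neighbor slot, the atom type that slot is reserved for. A minimal numpy sketch with a toy `sel_a` (hypothetical values) showing what the mask contains:

    import numpy as np

    # Hypothetical selection: up to 3 neighbors of type 0 and 2 of type 1.
    sel_a = [3, 2]
    nei_type = np.array([])
    for ii in range(len(sel_a)):
        # repeat each type index once per neighbor slot of that type
        nei_type = np.append(nei_type, ii * np.ones(sel_a[ii]))
    print(nei_type)  # [0. 0. 0. 1. 1.] -- one type label per neighbor slot
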
@@ -214,7 +219,7 @@ def compute_input_stats (self,
sumr2 = np.sum(sumr2, axis = 0)
suma2 = np.sum(suma2, axis = 0)
for type_i in range(self.ntypes) :
-            davgunit = [sumr[type_i]/sumn[type_i], 0, 0, 0]
+            davgunit = [sumr[type_i]/(sumn[type_i]+1e-15), 0, 0, 0]
dstdunit = [self._compute_std(sumr2[type_i], sumr[type_i], sumn[type_i]),
self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
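
The added `1e-15` guards the average against atom types that collected no neighbor statistics. A tiny sketch of the failure mode being avoided (hypothetical numbers):

    sumr, sumn = 0.0, 0            # no neighbors of this type were seen
    # davg = sumr / sumn           # ZeroDivisionError
    davg = sumr / (sumn + 1e-15)   # degrades gracefully to 0.0
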
@@ -440,11 +445,15 @@ def _pass_filter(self,
reuse = None,
suffix = '',
trainable = True) :
+        if input_dict is not None:
+            type_embedding = input_dict.get('type_embedding', None)
+        else:
+            type_embedding = None
start_index = 0
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
output = []
output_qmat = []
-        if not self.type_one_side:
+        if not self.type_one_side and type_embedding is None:
for type_i in range(self.ntypes):
inputs_i = tf.slice (inputs,
[ 0, start_index* self.ndescrpt],
@@ -460,7 +469,7 @@
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
-            layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
+            layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
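
`_pass_filter` now pulls an optional `type_embedding` tensor out of `input_dict`; when it is present, the per-type loop collapses into a single shared `filter_type_all` network and type identity enters through the embedding instead. A hedged sketch of just that dispatch (`build_per_type_nets` and `build_shared_net` are hypothetical stand-ins for the two branches):

    def pass_filter_dispatch(input_dict, type_one_side, build_per_type_nets, build_shared_net):
        # type_embedding is optional; None reproduces the old behaviour
        type_embedding = None
        if input_dict is not None:
            type_embedding = input_dict.get('type_embedding', None)
        if not type_one_side and type_embedding is None:
            return build_per_type_nets()          # one embedding net per atom type
        return build_shared_net(type_embedding)   # one net shared across all types
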
@@ -516,75 +525,152 @@ def _compute_dstats_sys_smth (self,


def _compute_std (self,sumv2, sumv, sumn) :
+        if sumn == 0:
+            return 1e-2
val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn))
if np.abs(val) < 1e-2:
val = 1e-2
return val


-    def _filter(self,
-                inputs,
-                type_input,
-                natoms,
-                activation_fn=tf.nn.tanh,
-                stddev=1.0,
-                bavg=0.0,
-                name='linear',
-                reuse=None,
-                seed=None,
-                trainable = True):
+    def _concat_type_embedding(
+            self,
+            xyz_scatter,
+            nframes,
+            natoms,
+            type_embedding,
+    ):
+        te_out_dim = type_embedding.get_shape().as_list()[-1]
+        # embedding of each neighbor slot's type: nnei x te_out_dim
+        nei_embed = tf.nn.embedding_lookup(type_embedding, tf.cast(self.nei_type, dtype = tf.int32))
+        nei_embed = tf.tile(nei_embed, (nframes * natoms[0], 1))
+        nei_embed = tf.reshape(nei_embed, [-1, te_out_dim])
+        embedding_input = tf.concat([xyz_scatter, nei_embed], 1)
+        if not self.type_one_side:
+            atm_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
+            atm_embed = tf.tile(atm_embed, (1, self.nnei))
+            atm_embed = tf.reshape(atm_embed, [-1, te_out_dim])
+            embedding_input = tf.concat([embedding_input, atm_embed], 1)
+        return embedding_input
+
+
+    def _filter_lower(
+            self,
+            type_i,
+            type_input,
+            start_index,
+            incrs_index,
+            inputs,
+            nframes,
+            natoms,
+            type_embedding = None,
+            is_exclude = False,
+            activation_fn = None,
+            bavg = 0.0,
+            stddev = 1.0,
+            seed = None,
+            trainable = True,
+            suffix = '',
+    ):
+        """
+        Input env matrix, returns R.G
+        """
+        outputs_size = [1] + self.filter_neuron
+        # cut-out inputs
+        # with natom x (nei_type_i x 4)
+        inputs_i = tf.slice(inputs,
+                            [ 0, start_index * 4],
+                            [-1, incrs_index * 4])
+        shape_i = inputs_i.get_shape().as_list()
+        # with (natom x nei_type_i) x 4
+        inputs_reshape = tf.reshape(inputs_i, [-1, 4])
+        # with (natom x nei_type_i) x 1
+        xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0], [-1,1]), [-1,1])
+        if type_embedding is not None:
+            type_embedding = tf.cast(type_embedding, self.filter_precision)
+            xyz_scatter = self._concat_type_embedding(
+                xyz_scatter, nframes, natoms, type_embedding)
+            if self.compress:
+                raise RuntimeError('compression of type embedded descriptor is not supported at the moment')
+        # with (natom x nei_type_i) x out_size
+        if self.compress and (not is_exclude):
+            info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
+            if self.type_one_side:
+                net = 'filter_-1_net_' + str(type_i)
+            else:
+                net = 'filter_' + str(type_input) + '_net_' + str(type_i)
+            return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
+        else:
+            if not is_exclude:
+                xyz_scatter = embedding_net(
+                    xyz_scatter,
+                    self.filter_neuron,
+                    self.filter_precision,
+                    activation_fn = activation_fn,
+                    resnet_dt = self.filter_resnet_dt,
+                    name_suffix = suffix,
+                    stddev = stddev,
+                    bavg = bavg,
+                    seed = seed,
+                    trainable = trainable)
+            else:
+                # excluded type pairs contribute zeros of the right shape
+                w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype = GLOBAL_TF_FLOAT_PRECISION)
+                xyz_scatter = tf.matmul(xyz_scatter, w)
+            # natom x nei_type_i x out_size
+            xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
+            return tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)
+
+
+    def _filter(
+            self,
+            inputs,
+            type_input,
+            natoms,
+            type_embedding = None,
+            activation_fn = tf.nn.tanh,
+            stddev = 1.0,
+            bavg = 0.0,
+            name = 'linear',
+            reuse = None,
+            seed = None,
+            trainable = True):
+        # number of frames, recovered from the flattened descriptor input
+        nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
# natom x (nei x 4)
shape = inputs.get_shape().as_list()
outputs_size = [1] + self.filter_neuron
outputs_size_2 = self.n_axis_neuron
with tf.variable_scope(name, reuse=reuse):
start_index = 0
-            xyz_scatter_total = []
-            for type_i in range(self.ntypes):
-                # cut-out inputs
-                # with natom x (nei_type_i x 4)
-                inputs_i = tf.slice(inputs,
-                                    [ 0, start_index * 4],
-                                    [-1, self.sel_a[type_i] * 4])
-                start_index += self.sel_a[type_i]
-                shape_i = inputs_i.get_shape().as_list()
-                # with (natom x nei_type_i) x 4
-                inputs_reshape = tf.reshape(inputs_i, [-1, 4])
-                # with (natom x nei_type_i) x 1
-                xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0], [-1,1]), [-1,1])
-                # with (natom x nei_type_i) x out_size
-                if self.compress and (type_input, type_i) not in self.exclude_types:
-                    info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
-                    if self.type_one_side:
-                        net = 'filter_-1_net_' + str(type_i)
-                    else:
-                        net = 'filter_' + str(type_input) + '_net_' + str(type_i)
-                    if type_i == 0:
-                        xyz_scatter_1 = op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
-                    else:
-                        xyz_scatter_1 += op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
-                else:
-                    if (type_input, type_i) not in self.exclude_types:
-                        xyz_scatter = embedding_net(xyz_scatter,
-                                                    self.filter_neuron,
-                                                    self.filter_precision,
-                                                    activation_fn = activation_fn,
-                                                    resnet_dt = self.filter_resnet_dt,
-                                                    name_suffix = "_"+str(type_i),
-                                                    stddev = stddev,
-                                                    bavg = bavg,
-                                                    seed = seed,
-                                                    trainable = trainable)
-                    else:
-                        w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION)
-                        xyz_scatter = tf.matmul(xyz_scatter, w)
-                    # natom x nei_type_i x out_size
-                    xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
-                    # xyz_scatter_total.append(xyz_scatter)
-                    if type_i == 0:
-                        xyz_scatter_1 = tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)
-                    else:
-                        xyz_scatter_1 += tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)
+        if type_embedding is None:
+            for type_i in range(self.ntypes):
+                ret = self._filter_lower(
+                    type_i, type_input,
+                    start_index, self.sel_a[type_i],
+                    inputs,
+                    nframes,
+                    natoms,
+                    type_embedding = type_embedding,
+                    is_exclude = (type_input, type_i) in self.exclude_types,
+                    activation_fn = activation_fn,
+                    stddev = stddev,
+                    bavg = bavg,
+                    seed = seed,
+                    trainable = trainable,
+                    suffix = "_" + str(type_i))
+                if type_i == 0:
+                    xyz_scatter_1 = ret
+                else:
+                    xyz_scatter_1 += ret
+                start_index += self.sel_a[type_i]
+        else:
+            # type_i is only consulted on the (rejected) compression path
+            xyz_scatter_1 = self._filter_lower(
+                0, type_input,
+                start_index, np.cumsum(self.sel_a)[-1],
+                inputs,
+                nframes,
+                natoms,
+                type_embedding = type_embedding,
+                is_exclude = False,
+                activation_fn = activation_fn,
+                stddev = stddev,
+                bavg = bavg,
+                seed = seed,
+                trainable = trainable)
# natom x nei x outputs_size
# xyz_scatter = tf.concat(xyz_scatter_total, axis=1)
# natom x nei x 4
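
To make the shapes in `_concat_type_embedding` concrete, here is a shape-tracing sketch in plain numpy rather than TF (toy sizes; all values hypothetical). Each neighbor's scalar component s(r_ij) is concatenated with the embedding of that neighbor's type, and, when `type_one_side` is off, the center atom's type embedding is appended as well:

    import numpy as np

    nframes, natoms, nnei, te_out_dim = 2, 4, 5, 8            # toy sizes
    xyz_scatter = np.random.rand(nframes * natoms * nnei, 1)  # s(r_ij), one row per neighbor slot
    type_embedding = np.random.rand(3, te_out_dim)            # one row per atom type
    nei_type = np.array([0, 0, 0, 1, 2])                      # type of each neighbor slot

    nei_embed = type_embedding[nei_type]                   # (nnei, te_out_dim), like embedding_lookup
    nei_embed = np.tile(nei_embed, (nframes * natoms, 1))  # repeat for every center atom
    embedding_input = np.concatenate([xyz_scatter, nei_embed], axis=1)
    print(embedding_input.shape)  # (40, 9) = (nframes*natoms*nnei, 1 + te_out_dim)
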
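`_filter_lower` ends by contracting the sliced environment matrix R against the per-neighbor network output G, the R.G product named in its docstring. A numpy sketch of that final `tf.matmul(..., transpose_a = True)` with toy shapes:

    import numpy as np

    natom, nei, out_size = 3, 5, 16           # toy sizes
    R = np.random.rand(natom, nei, 4)         # env matrix slice: rows (s, s*x/r, s*y/r, s*z/r)
    G = np.random.rand(natom, nei, out_size)  # embedding-net output per neighbor
    # transpose_a=True transposes the last two axes of the first argument
    RG = np.matmul(R.transpose(0, 2, 1), G)
    print(RG.shape)  # (3, 4, 16): natom x 4 x out_size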

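Finally, the rewritten `_filter` branches once: without a type embedding it keeps the old per-type loop and sums the partial R.G products; with one, a single `_filter_lower` call covers all `np.cumsum(self.sel_a)[-1]` neighbor slots through one shared network. A hedged sketch of the control flow (`filter_lower` stands in for the method above):

    import numpy as np

    def filter_paths(sel_a, type_embedding, filter_lower):
        if type_embedding is None:
            # old path: one sub-network per neighbor type, partial products summed
            out, start = 0, 0
            for sel in sel_a:
                out = out + filter_lower(start, sel)
                start += sel
            return out
        # new path: one shared network over every neighbor slot at once
        return filter_lower(0, np.cumsum(sel_a)[-1])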