-
Notifications
You must be signed in to change notification settings - Fork 492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Finish model compression for se_r descriptor! #1361
Changes from 6 commits
ae433bb
c2d2723
5ba1976
aefa443
9d45e3e
5f50677
d9f508d
e336e3f
18d6266
079ae29
c5a7888
e613be5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -82,20 +82,40 @@ def __init__(self, | |
self.sub_graph, self.sub_graph_def = self._load_sub_graph() | ||
self.sub_sess = tf.Session(graph = self.sub_graph) | ||
|
||
try: | ||
self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') | ||
except Exception: | ||
self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') | ||
if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
try: | ||
self.sel_a = self.graph.get_operation_by_name('ProdEnvMatR').get_attr('sel') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatR') | ||
except KeyError: | ||
self.sel_a = self.graph.get_operation_by_name('DescrptSeR').get_attr('sel') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeR') | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): | ||
try: | ||
self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') | ||
except KeyError: | ||
self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): | ||
try: | ||
self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') | ||
except KeyError: | ||
self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') | ||
self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') | ||
else: | ||
raise RuntimeError("Unsupported descriptor") | ||
|
||
self.davg = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_avg') | ||
self.dstd = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_std') | ||
self.ntypes = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/ntypes') | ||
|
||
|
||
self.rcut = self.prod_env_mat_op.get_attr('rcut_r') | ||
self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_r_smth') | ||
if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
self.rcut = self.prod_env_mat_op.get_attr('rcut') | ||
self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_smth') | ||
else: | ||
self.rcut = self.prod_env_mat_op.get_attr('rcut_r') | ||
self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_r_smth') | ||
|
||
self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(self.graph_def, suffix=self.suffix) | ||
|
||
|
@@ -172,6 +192,21 @@ def build(self, | |
net = "filter_" + str(ii) + "_net_" + str(jj) | ||
self._build_lower(net, xx, idx, upper, lower, stride0, stride1, extrapolate) | ||
idx += 1 | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
xx = np.arange(lower, upper, stride0, dtype = self.data_type) | ||
xx = np.append(xx, np.arange(upper, extrapolate * upper, stride1, dtype = self.data_type)) | ||
xx = np.append(xx, np.array([extrapolate * upper], dtype = self.data_type)) | ||
self.nspline = int((upper - lower) / stride0 + (extrapolate * upper - upper) / stride1) | ||
for ii in range(self.table_size): | ||
if self.type_one_side or (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types: | ||
if self.type_one_side: | ||
net = "filter_-1_net_" + str(ii) | ||
else: | ||
net = "filter_" + str(ii // self.ntypes) + "_net_" + str(ii % self.ntypes) | ||
self._build_lower(net, xx, ii, upper, lower, stride0, stride1, extrapolate) | ||
else: | ||
raise RuntimeError("Unsupported descriptor") | ||
|
||
return lower, upper | ||
|
||
def _build_lower(self, net, xx, idx, upper, lower, stride0, stride1, extrapolate): | ||
|
@@ -185,6 +220,9 @@ def _build_lower(self, net, xx, idx, upper, lower, stride0, stride1, extrapolate | |
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): | ||
tt = np.full((self.nspline, self.last_layer_size), stride1) | ||
tt[int((lower - extrapolate * lower) / stride1) + 1:(int((lower - extrapolate * lower) / stride1) + int((upper - lower) / stride0)), :] = stride0 | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
tt = np.full((self.nspline, self.last_layer_size), stride1) | ||
tt[:int((upper - lower) / stride0), :] = stride0 | ||
else: | ||
raise RuntimeError("Unsupported descriptor") | ||
|
||
|
@@ -225,6 +263,18 @@ def _get_bias(self): | |
for jj in range(ii, self.ntypes): | ||
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}_{jj}"] | ||
bias["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
if self.type_one_side: | ||
for ii in range(0, self.ntypes): | ||
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/bias_{layer}_{ii}"] | ||
bias["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
else: | ||
for ii in range(0, self.ntypes * self.ntypes): | ||
if (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types: | ||
node = self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/bias_{layer}_{ii % self.ntypes}"] | ||
bias["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
else: | ||
bias["layer_" + str(layer)].append(np.array([])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an |
||
return bias | ||
|
||
def _get_matrix(self): | ||
|
@@ -248,6 +298,18 @@ def _get_matrix(self): | |
for jj in range(ii, self.ntypes): | ||
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}_{jj}"] | ||
matrix["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
if self.type_one_side: | ||
for ii in range(0, self.ntypes): | ||
node = self.embedding_net_nodes[f"filter_type_all{self.suffix}/matrix_{layer}_{ii}"] | ||
matrix["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
else: | ||
for ii in range(0, self.ntypes * self.ntypes): | ||
if (ii // self.ntypes, ii % self.ntypes) not in self.exclude_types: | ||
node = self.embedding_net_nodes[f"filter_type_{ii // self.ntypes}{self.suffix}/matrix_{layer}_{ii % self.ntypes}"] | ||
matrix["layer_" + str(layer)].append(tf.make_ndarray(node)) | ||
else: | ||
matrix["layer_" + str(layer)].append(np.array([])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an |
||
return matrix | ||
|
||
# one-by-one executions | ||
|
@@ -317,6 +379,9 @@ def _get_env_mat_range(self, | |
var = np.square(sw / (min_nbor_dist * self.dstd[:, 1:4])) | ||
lower = np.min(-var) | ||
upper = np.max(var) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
lower = np.min(-self.davg[:, 0] / self.dstd[:, 0]) | ||
upper = np.max(((1 / min_nbor_dist) * sw - self.davg[:, 0]) / self.dstd[:, 0]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an |
||
log.info('training data with lower boundary: ' + str(lower)) | ||
log.info('training data with upper boundary: ' + str(upper)) | ||
return math.floor(lower), math.ceil(upper) | ||
|
@@ -342,6 +407,10 @@ def _get_layer_size(self): | |
layer_size = len(self.embedding_net_nodes) // (self.ntypes * 2) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): | ||
layer_size = len(self.embedding_net_nodes) // int(comb(self.ntypes + 1, 2) * 2) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
layer_size = len(self.embedding_net_nodes) // ((self.ntypes * self.ntypes - len(self.exclude_types)) * 2) | ||
if self.type_one_side : | ||
layer_size = len(self.embedding_net_nodes) // (self.ntypes * 2) | ||
return layer_size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an |
||
|
||
def _get_table_size(self): | ||
|
@@ -352,6 +421,10 @@ def _get_table_size(self): | |
table_size = self.ntypes | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): | ||
table_size = int(comb(self.ntypes + 1, 2)) | ||
elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): | ||
table_size = self.ntypes * self.ntypes | ||
if self.type_one_side : | ||
table_size = self.ntypes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK~ |
||
return table_size | ||
|
||
def _get_data_type(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
DescrptSeA
andDescrptSeT
are the same here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, they are the same. But considering the possible changes later, I have explicitly distinguished them.