Skip to content

Commit

Permalink
add ut for model compression (#586)
Browse files Browse the repository at this point in the history
* add ut for model compression

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* fix bug of UT

* adjust code structure model compression
  • Loading branch information
denghuilu committed May 9, 2021
1 parent bcdd9f6 commit 33c0c0f
Show file tree
Hide file tree
Showing 13 changed files with 685 additions and 9 deletions.
5 changes: 5 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,8 @@ def _concat_type_embedding(

def _filter_lower(
self,
type_i,
type_input,
start_index,
incrs_index,
inputs,
Expand Down Expand Up @@ -638,9 +640,11 @@ def _filter(
outputs_size_2 = self.n_axis_neuron
with tf.variable_scope(name, reuse=reuse):
start_index = 0
type_i = 0
if type_embedding is None:
for type_i in range(self.ntypes):
ret = self._filter_lower(
type_i, type_input,
start_index, self.sel_a[type_i],
inputs,
nframes,
Expand All @@ -660,6 +664,7 @@ def _filter(
start_index += self.sel_a[type_i]
else :
xyz_scatter_1 = self._filter_lower(
type_i, type_input,
start_index, np.cumsum(self.sel_a)[-1],
inputs,
nframes,
Expand Down
12 changes: 7 additions & 5 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def compress(
jdata = j_loader(INPUT)
if "model" not in jdata.keys():
jdata = convert_input_v0_v1(jdata, warning=True, dump="input_v1_compat.json")
jdata = normalize(jdata)
jdata["model"]["compress"] = {}
jdata["model"]["compress"]["type"] = 'se_e2_a'
jdata["model"]["compress"]["compress"] = True
jdata["model"]["compress"]["model_file"] = input
jdata["model"]["compress"]["table_config"] = [
Expand All @@ -75,20 +75,22 @@ def compress(
10 * stride,
int(frequency),
]
# be careful here, if one want to refine the model
jdata["training"]["numb_steps"] = jdata["training"]["save_freq"]
jdata = normalize(jdata)


# check the descriptor info of the input file
assert (
jdata["model"]["descriptor"]["type"] == "se_a"
), "Model compression error: descriptor type must be se_a!"
jdata["model"]["descriptor"]["type"] == "se_a" or jdata["model"]["descriptor"]["type"] == "se_e2_a"
), "Model compression error: descriptor type must be se_a or se_e2_a!"
assert (
jdata["model"]["descriptor"]["resnet_dt"] is False
), "Model compression error: descriptor resnet_dt must be false!"

# stage 1: training or refining the model with tabulation
log.info("\n\n")
log.info("stage 1: train or refine the model with tabulation")
# be careful here, if one want to refine the model
jdata["training"]["stop_batch"] = jdata["training"]["save_freq"]
control_file = "compress.json"
with open(control_file, "w") as fp:
json.dump(jdata, fp, indent=4)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def _build_lower(
1,
activation_fn = None,
bavg = bias_atom_e,
name='final_layer_'+suffix,
name='final_layer'+suffix,
reuse=reuse,
seed = self.seed,
precision = self.fitting_precision,
Expand Down
25 changes: 25 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,29 @@ def modifier_variant_type_args():
optional = False,
doc = doc_modifier_type)

# --- model compression configurations: --- #
def model_compression():
    """Build the argument specifications of the model compression section.

    Returns
    -------
    list
        Three ``Argument`` objects, in this order: the boolean ``compress``
        switch, the ``model_file`` string and the ``table_config`` list.
    """
    # The previous text ("The name of the frozen model file.") described a
    # file name, but this argument is declared as a bool; the file name is
    # carried by ``model_file`` below.
    doc_compress = "Whether to enable model compression."
    # Plain literals: these strings contain no placeholders, so no f-prefix.
    doc_model_file = "The input model file, which will be compressed by the DeePMD-kit."
    doc_table_config = "The arguments of model compression, including extrapolate(scale of model extrapolation), stride(uniform stride of tabulation's first and second table), and frequency(frequency of tabulation overflow check)."

    # NOTE(review): every argument is declared ``optional = False`` and also
    # given a default; kept exactly as in the original -- confirm against the
    # argument-checking library whether the default is honoured in that case.
    return [
        Argument("compress", bool, optional = False, default = True, doc = doc_compress),
        Argument("model_file", str, optional = False, default = 'frozen_model.pb', doc = doc_model_file),
        # presumably ordered as [extrapolate, first stride, second stride,
        # frequency], following the wording of ``doc_table_config`` -- TODO confirm
        Argument("table_config", list, optional = False, default = [5, 0.01, 0.1, -1], doc = doc_table_config),
    ]

# --- model compression type variant: --- #
def model_compression_type_args():
    """Describe the ``type`` variant of the model compression section.

    Returns
    -------
    Variant
        A variant keyed on ``"type"`` with a single choice, ``se_e2_a``
        (also reachable through the alias ``se_a``), whose sub-arguments
        are those produced by ``model_compression()``. The variant is
        optional and falls back to the ``se_e2_a`` tag.
    """
    type_doc = "The type of model compression, which should be consistent with the descriptor type."

    # The only choice currently offered; ``se_a`` is registered as an alias.
    se_e2_a_choice = Argument(
        "se_e2_a",
        dict,
        model_compression(),
        alias = ['se_a'],
    )

    return Variant(
        "type",
        [se_e2_a_choice],
        optional = True,
        default_tag = 'se_e2_a',
        doc = type_doc,
    )


def model_args ():
doc_type_map = 'A list of strings. Give the name to each type of atoms.'
Expand All @@ -333,6 +356,7 @@ def model_args ():
doc_smin_alpha = 'The short-range tabulated interaction will be swithed according to the distance of the nearest neighbor. This distance is calculated by softmin. This parameter is the decaying parameter in the softmin. It is only required when `use_srtab` is provided.'
doc_sw_rmin = 'The lower boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.'
doc_sw_rmax = 'The upper boundary of the interpolation between short-range tabulated interaction and DP. It is only required when `use_srtab` is provided.'
doc_compress_config = 'Model compression configurations'

ca = Argument("model", dict,
[Argument("type_map", list, optional = True, doc = doc_type_map),
Expand All @@ -346,6 +370,7 @@ def model_args ():
Argument("descriptor", dict, [], [descrpt_variant_type_args()], doc = doc_descrpt),
Argument("fitting_net", dict, [], [fitting_variant_type_args()], doc = doc_fitting),
Argument("modifier", dict, [], [modifier_variant_type_args()], optional = True, doc = doc_modifier),
Argument("compress", dict, [], [model_compression_type_args()], optional = True, doc = doc_compress_config)
])
# print(ca.gen_doc())
return ca
Expand Down
10 changes: 7 additions & 3 deletions deepmd/utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,17 @@ def __init__(self,
self.sub_graph, self.sub_graph_def = self._load_sub_graph()
self.sub_sess = tf.Session(graph = self.sub_graph)

self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a')
try:
self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a')
self.descrpt = self.graph.get_operation_by_name ('ProdEnvMatA')
except Exception:
self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a')
self.descrpt = self.graph.get_operation_by_name ('DescrptSeA')
self.ntypes = self._get_tensor_value(self.graph.get_tensor_by_name ('descrpt_attr/ntypes:0'))

self.davg = self._get_tensor_value(self.graph.get_tensor_by_name ('descrpt_attr/t_avg:0'))
self.dstd = self._get_tensor_value(self.graph.get_tensor_by_name ('descrpt_attr/t_std:0'))

self.descrpt = self.graph.get_operation_by_name ('DescrptSeA')

self.rcut = self.descrpt.get_attr('rcut_r')
self.rcut_smth = self.descrpt.get_attr('rcut_r_smth')

Expand Down
Binary file added source/tests/model_compression/data/set.000/box.npy
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
192 changes: 192 additions & 0 deletions source/tests/model_compression/data/type.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
0
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
2 changes: 2 additions & 0 deletions source/tests/model_compression/data/type_map.raw
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
O
H

0 comments on commit 33c0c0f

Please sign in to comment.