Skip to content

Commit

Permalink
support dp compress
Browse files Browse the repository at this point in the history
Set `dp_compress` to `true` in parameters will enable model compression.
  • Loading branch information
njzjz committed Nov 30, 2021
1 parent df9095d commit fcf4574
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,7 @@ The bold notation of key (such aas **type_map**) means that it's a necessary key
| training_iter0_model_path | list of string | ["/path/to/model0_ckpt/", ...] | The model used to init the first iter training. Number of element should be equal to `numb_models` |
| training_init_model | bool | False | Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from `training_iter0_model_path`. |
| **default_training_param** | Dict | | Training parameters for `deepmd-kit` in `00.train`. <br /> You can find instructions from here: (https://github.com/deepmodeling/deepmd-kit)..<br /> |
| dp_compress | bool | false | Use `dp compress` to compress the model. Default is false. |
| *#Exploration*
| **model_devi_dt** | Float | 0.002 (recommend) | Timestep for MD |
| **model_devi_skip** | Integer | 0 | Number of structures skipped for fp in each MD
Expand Down
10 changes: 9 additions & 1 deletion dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,8 @@ def run_train (iter_index,
commands.append(command)
command = '%s freeze' % train_command
commands.append(command)
if jdata.get("dp_compress", False):
commands.append("%s compress" % train_command)
else:
raise RuntimeError("DP-GEN currently only supports for DeePMD-kit 1.x or 2.x version!" )

Expand All @@ -536,6 +538,8 @@ def run_train (iter_index,
]
backward_files = ['frozen_model.pb', 'lcurve.out', 'train.log']
backward_files+= ['model.ckpt.meta', 'model.ckpt.index', 'model.ckpt.data-00000-of-00001', 'checkpoint']
if jdata.get("dp_compress", False):
backward_files.append('frozen_model_compressed.pb')
init_data_sys_ = jdata['init_data_sys']
init_data_sys = []
for ii in init_data_sys_ :
Expand Down Expand Up @@ -621,7 +625,11 @@ def post_train (iter_index,
return
# symlink models
for ii in range(numb_models) :
task_file = os.path.join(train_task_fmt % ii, 'frozen_model.pb')
if not jdata.get("dp_compress", False):
model_name = 'frozen_model.pb'
else:
model_name = 'frozen_model_compressed.pb'
task_file = os.path.join(train_task_fmt % ii, model_name)
ofile = os.path.join(work_path, 'graph.%03d.pb' % ii)
if os.path.isfile(ofile) :
os.remove(ofile)
Expand Down

0 comments on commit fcf4574

Please sign in to comment.