diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 21f078b763..2b3f8a249c 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -372,7 +372,7 @@ def _build_training(self): if self.mixed_prec is not None: _TF_VERSION = Version(TF_VERSION) # check the TF_VERSION, when TF < 1.12, mixed precision is not allowed - if _TF_VERSION < Version('1.12.0'): + if _TF_VERSION < Version('1.14.0'): raise RuntimeError("TensorFlow version %s is not compatible with the mixed precision setting. Please consider upgrading your TF version!" % TF_VERSION) elif _TF_VERSION < Version('2.4.0'): optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer) diff --git a/deepmd/utils/network.py b/deepmd/utils/network.py index 57dd90f893..befd571f24 100644 --- a/deepmd/utils/network.py +++ b/deepmd/utils/network.py @@ -79,15 +79,13 @@ def one_layer(inputs, if use_timestep : if mixed_prec is not None and not final_layer: idt = tf.cast(idt, get_precision(mixed_prec['compute_prec'])) - return tf.reshape(activation_fn(hidden), [-1, outputs_size]) * idt + hidden = tf.reshape(activation_fn(hidden), [-1, outputs_size]) * idt else : - return tf.reshape(activation_fn(hidden), [-1, outputs_size]) - else: - if useBN: - None - # return self._batch_norm(hidden, name=name+'_normalization', reuse=reuse) - else: - return hidden + hidden = tf.reshape(activation_fn(hidden), [-1, outputs_size]) + + if mixed_prec is not None: + hidden = tf.cast(hidden, get_precision(mixed_prec['output_prec'])) + return hidden def embedding_net_rand_seed_shift( @@ -237,6 +235,8 @@ def embedding_net(xx, xx = tf.concat([xx,xx], 1) + hidden else: xx = hidden + if mixed_prec is not None: + xx = tf.cast(xx, get_precision(mixed_prec['output_prec'])) return xx def variable_summaries(var: tf.Variable, name: str): diff --git a/source/tests/test_mixed_prec_training.py b/source/tests/test_mixed_prec_training.py new file mode 100644 index 0000000000..28a1d485f7 --- /dev/null +++ b/source/tests/test_mixed_prec_training.py @@ -0,0 +1,60 @@ +import os,json +import numpy as np +import unittest +import subprocess as sp +from packaging.version import Version + +from deepmd.infer import DeepPot +# from deepmd.entrypoints.compress import compress +from common import j_loader, tests_path +from deepmd.env import TF_VERSION + + +def _file_delete(file) : + if os.path.isdir(file): + os.rmdir(file) + elif os.path.isfile(file): + os.remove(file) + +def _subprocess_run(command): + popen = sp.Popen(command.split(), shell=False, stdout=sp.PIPE, stderr=sp.STDOUT) + for line in iter(popen.stdout.readline, b''): + if hasattr(line, 'decode'): + line = line.decode('utf-8') + line = line.rstrip() + print(line) + popen.wait() + return popen.returncode + +class TestMixedPrecTraining(unittest.TestCase): + def setUp(self): + data_file = str(tests_path / os.path.join("model_compression", "data")) + self.INPUT = str(tests_path / "input.json") + jdata = j_loader(str(tests_path / os.path.join("model_compression", "input.json"))) + jdata["training"]["training_data"]["systems"] = data_file + jdata["training"]["validation_data"]["systems"] = data_file + jdata["training"]["mixed_precision"] = {} + jdata["training"]["mixed_precision"]["compute_prec"] = "float16" + jdata["training"]["mixed_precision"]["output_prec"] = "float32" + with open(self.INPUT, "w") as fp: + json.dump(jdata, fp, indent=4) + + def test_training(self): + _TF_VERSION = Version(TF_VERSION) + # check the TF_VERSION, when TF < 1.12, mixed precision is not allowed + if _TF_VERSION >= Version('1.14.0'): + ret = _subprocess_run("dp train " + self.INPUT) + np.testing.assert_equal(ret, 0, 'DP train failed!') + + def tearDown(self): + _file_delete(self.INPUT) + _file_delete("out.json") + _file_delete("checkpoint") + _file_delete("model.ckpt.meta") + _file_delete("model.ckpt.index") + _file_delete("model.ckpt.data-00000-of-00001") + _file_delete("model.ckpt-100.meta") + _file_delete("model.ckpt-100.index") + _file_delete("model.ckpt-100.data-00000-of-00001") + _file_delete("input_v2_compat.json") + _file_delete("lcurve.out")