Enable mixed precision support for deepmd-kit (#1285)
* enable mixed precision support for dp

* set the default embedding net & fitting net precision

* add doc for mixed precision

* fix typo

* fix UT bug

* use input script to control the mixed precision workflow

* add tf version check for mixed precision

* Update training-advanced.md

* fix typo

* fix TF_VERSION control

* fix TF_VERSION comparison

* enable mixed precision for hybrid descriptor

* Update network.py

* use parameter to control the network mixed precision output precision

* add example for mixed precision training workflow

* fix lint errors
denghuilu committed Nov 23, 2021
1 parent 9c517eb commit f40e14e
Showing 13 changed files with 273 additions and 19 deletions.
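For reference, a minimal sketch of the input-script section that drives this feature, written here as a Python dict (the actual input script is JSON with the same keys). The key names and the accepted values come from the validation added to deepmd/train/trainer.py below; everything else is illustrative.

```python
# Sketch of the "training" section that enables mixed precision.
training = {
    "mixed_precision": {
        "output_prec": "float32",   # precision of network outputs; the trainer check accepts only float32
        "compute_prec": "float16",  # precision of intermediate math; the trainer check accepts only float16
    },
}
```

With this section present, the trainer below picks it up via `tr_data.get('mixed_precision', None)`, validates it, and enables the mixed precision graph rewrite.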
18 changes: 18 additions & 0 deletions deepmd/descriptor/descriptor.py
@@ -262,6 +262,24 @@ def enable_compression(self,
raise NotImplementedError(
"Descriptor %s doesn't support compression!" % type(self).__name__)

def enable_mixed_precision(self, mixed_prec: dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the embedding net

Notes
-----
This method is called by others when the descriptor supports mixed precision.
"""
raise NotImplementedError(
"Descriptor %s doesn't support mixed precision training!"
% type(self).__name__
)

@abstractmethod
def prod_force_virial(self,
atom_ener: tf.Tensor,
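The base class makes the hook opt-in: a descriptor that does not override `enable_mixed_precision` raises `NotImplementedError`, so unsupported descriptors fail fast. A small illustrative sketch of how a caller sees this (the fallback `print` is illustrative; the trainer in this commit simply lets the error propagate):

```python
class PlainDescriptor:
    """Stand-in for a descriptor that has not implemented the new hook."""
    def enable_mixed_precision(self, mixed_prec=None):
        raise NotImplementedError(
            "Descriptor %s doesn't support mixed precision training!"
            % type(self).__name__)

mixed_prec = {"output_prec": "float32", "compute_prec": "float16"}
try:
    PlainDescriptor().enable_mixed_precision(mixed_prec)
except NotImplementedError as err:
    print("mixed precision unavailable:", err)
```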
14 changes: 14 additions & 0 deletions deepmd/descriptor/hybrid.py
@@ -264,6 +264,20 @@ def enable_compression(self,
for idx, ii in enumerate(self.descrpt_list):
ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}")


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the embedding net
"""
for ii in self.descrpt_list:
ii.enable_mixed_precision(mixed_prec)


def init_variables(self,
model_file : str,
suffix : str = "",
18 changes: 17 additions & 1 deletion deepmd/descriptor/se_a.py
@@ -160,6 +160,7 @@ def __init__ (self,
self.davg = None
self.compress = False
self.embedding_net_variables = None
self.mixed_prec = None
self.place_holders = {}
nei_type = np.array([])
for ii in range(self.ntypes):
@@ -348,6 +349,18 @@ def enable_compression(self,
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr%s/t_std' % suffix)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the embedding net
"""
self.mixed_prec = mixed_prec
self.filter_precision = get_precision(mixed_prec['output_prec'])


def build (self,
coord_ : tf.Tensor,
@@ -708,7 +721,8 @@ def _filter_lower(
seed = self.seed,
trainable = trainable,
uniform_seed = self.uniform_seed,
initial_variables = self.embedding_net_variables)
initial_variables = self.embedding_net_variables,
mixed_prec = self.mixed_prec)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
else:
# we can safely return the final xyz_scatter filled with zero directly
@@ -735,6 +749,8 @@ def _filter(
name='linear',
reuse=None,
trainable = True):
if self.mixed_prec is not None:
inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec']))
nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
# natom x (nei x 4)
shape = inputs.get_shape().as_list()
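Taken together, the se_a changes implement a cast-in/cast-out discipline: `_filter` casts its inputs to the compute precision, the embedding net runs in that dtype, and `filter_precision` (set from `output_prec`) governs what the descriptor hands back. A minimal standalone sketch of that pattern, with `tf.tanh` standing in for the real embedding net and a small dict standing in for `deepmd.common.get_precision`:

```python
import tensorflow as tf

_PRECISIONS = {"float16": tf.float16, "float32": tf.float32, "float64": tf.float64}

def filter_sketch(inputs, mixed_prec):
    x = tf.cast(inputs, _PRECISIONS[mixed_prec["compute_prec"]])  # compute in float16
    y = tf.tanh(x)                                                # embedding-net stand-in
    return tf.cast(y, _PRECISIONS[mixed_prec["output_prec"]])     # emit float32

out = filter_sketch(tf.ones([4, 8]),
                    {"compute_prec": "float16", "output_prec": "float32"})
print(out.dtype)  # float32
```

Casting the result back to float32 is what lets the rest of the model, and ultimately the loss, stay in high precision while only the embedding-net arithmetic runs in float16.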
1 change: 1 addition & 0 deletions deepmd/env.py
@@ -39,6 +39,7 @@
"TRANSFER_PATTERN",
"FITTING_NET_PATTERN",
"EMBEDDING_NET_PATTERN",
"TF_VERSION"
]

SHARED_LIB_MODULE = "op"
22 changes: 18 additions & 4 deletions deepmd/fit/dipole.py
@@ -77,6 +77,7 @@ def __init__ (self,
self.dim_rot_mat = self.dim_rot_mat_1 * 3
self.useBN = False
self.fitting_net_variables = None
self.mixed_prec = None

def get_sel_type(self) -> int:
"""
@@ -141,12 +142,12 @@ def build (self,
layer = inputs_i
for ii in range(0,len(self.n_neuron)) :
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec)
else :
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# (nframes x natoms) x naxis
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec, final_layer = True)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# (nframes x natoms) x 1 * naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], 1, self.dim_rot_mat_1])
@@ -177,4 +178,17 @@ def init_variables(self,
model_file : str
The input frozen model file
"""
self.fitting_net_variables = get_fitting_net_variables(model_file)
self.fitting_net_variables = get_fitting_net_variables(model_file)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the fitting net
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec['output_prec'])
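The dipole fitting now threads `mixed_prec` into every `one_layer` call and marks the last call with `final_layer = True`. network.py is not shown in this view, so the following is a hedged sketch of how those two arguments plausibly interact, inferred from the call sites above and the commit message "use parameter to control the network mixed precision output precision"; the helper name and body are assumptions, not the actual implementation.

```python
import tensorflow as tf

_PREC = {"float16": tf.float16, "float32": tf.float32}

def layer_dtype(precision, mixed_prec=None, final_layer=False):
    """Illustrative helper: pick the dtype a layer computes in."""
    if mixed_prec is None:
        return precision                          # mixed precision off: keep the configured dtype
    if final_layer:
        return _PREC[mixed_prec["output_prec"]]   # final layer restores the float32 output
    return _PREC[mixed_prec["compute_prec"]]      # hidden layers run in float16
```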
26 changes: 22 additions & 4 deletions deepmd/fit/ener.py
@@ -150,6 +150,7 @@ def __init__ (self,
self.aparam_inv_std = None

self.fitting_net_variables = None
self.mixed_prec = None

def get_numb_fparam(self) -> int:
"""
@@ -293,7 +294,8 @@ def _build_lower(
precision = self.fitting_precision,
trainable = self.trainable[ii],
uniform_seed = self.uniform_seed,
initial_variables = self.fitting_net_variables)
initial_variables = self.fitting_net_variables,
mixed_prec = self.mixed_prec)
else :
layer = one_layer(
layer,
precision = self.fitting_precision,
trainable = self.trainable[ii],
uniform_seed = self.uniform_seed,
initial_variables = self.fitting_net_variables)
initial_variables = self.fitting_net_variables,
mixed_prec = self.mixed_prec)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
final_layer = one_layer(
layer,
precision = self.fitting_precision,
trainable = self.trainable[-1],
uniform_seed = self.uniform_seed,
initial_variables = self.fitting_net_variables)
initial_variables = self.fitting_net_variables,
mixed_prec = self.mixed_prec,
final_layer = True)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift

return final_layer
@@ -494,4 +499,17 @@ def init_variables(self,
model_file : str
The input frozen model file
"""
self.fitting_net_variables = get_fitting_net_variables(model_file)
self.fitting_net_variables = get_fitting_net_variables(model_file)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the fitting net
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec['output_prec'])
35 changes: 30 additions & 5 deletions deepmd/fit/polar.py
@@ -79,7 +79,7 @@ def build (self,
else :
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision)
# (nframes x natoms) x 9
final_layer = one_layer(layer, 9, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision)
final_layer = one_layer(layer, 9, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, final_layer = True)
# (nframes x natoms) x 3 x 3
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], 3, 3])
# (nframes x natoms) x 3 x 3
@@ -194,6 +194,7 @@ def __init__ (self,
self.dim_rot_mat = self.dim_rot_mat_1 * 3
self.useBN = False
self.fitting_net_variables = None
self.mixed_prec = None

def get_sel_type(self) -> List[int]:
"""
@@ -324,17 +325,17 @@ def build (self,
layer = inputs_i
for ii in range(0,len(self.n_neuron)) :
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec)
else :
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
if self.fit_diag :
bavg = np.zeros(self.dim_rot_mat_1)
# bavg[0] = self.avgeig[0]
# bavg[1] = self.avgeig[1]
# bavg[2] = self.avgeig[2]
# (nframes x natoms) x naxis
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, bavg = bavg, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, bavg = bavg, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec, final_layer = True)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# (nframes x natoms) x naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1])
# bavg[1*self.dim_rot_mat_1+1] = self.avgeig[1]
# bavg[2*self.dim_rot_mat_1+2] = self.avgeig[2]
# (nframes x natoms) x (naxis x naxis)
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, bavg = bavg, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables)
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, bavg = bavg, precision = self.fitting_precision, uniform_seed = self.uniform_seed, initial_variables = self.fitting_net_variables, mixed_prec = self.mixed_prec, final_layer = True)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# (nframes x natoms) x naxis x naxis
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1, self.dim_rot_mat_1])
@@ -387,6 +388,19 @@ def init_variables(self,
self.fitting_net_variables = get_fitting_net_variables(model_file)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the fitting net
"""
self.mixed_prec = mixed_prec
self.fitting_precision = get_precision(mixed_prec['output_prec'])


class GlobalPolarFittingSeA () :
"""
Fit the system polarizability with descriptor se_a
@@ -509,3 +523,14 @@ def init_variables(self,
"""
self.polar_fitting.init_variables(model_file)


def enable_mixed_precision(self, mixed_prec : dict = None) -> None:
"""
Receive the mixed precision setting.

Parameters
----------
mixed_prec
The mixed precision setting used in the fitting net
"""
self.polar_fitting.enable_mixed_precision(mixed_prec)
28 changes: 26 additions & 2 deletions deepmd/train/trainer.py
@@ -7,6 +7,8 @@
import shutil
import google.protobuf.message
import numpy as np
from packaging.version import Version

from deepmd.env import tf
from deepmd.env import get_tf_session_config
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
@@ -23,13 +25,13 @@
from deepmd.utils.graph import get_tensor_by_name

from tensorflow.python.client import timeline
from deepmd.env import op_module
from deepmd.env import op_module, TF_VERSION
from deepmd.utils.errors import GraphWithoutTensorError

# load grad of force module
import deepmd.op

from deepmd.common import j_must_have, ClassArg, data_requirement
from deepmd.common import j_must_have, ClassArg, data_requirement, get_precision

log = logging.getLogger(__name__)

@@ -227,6 +229,13 @@ def _init_param(self, jdata):
self.tensorboard = self.run_opt.is_chief and tr_data.get('tensorboard', False)
self.tensorboard_log_dir = tr_data.get('tensorboard_log_dir', 'log')
self.tensorboard_freq = tr_data.get('tensorboard_freq', 1)
self.mixed_prec = tr_data.get('mixed_precision', None)
if self.mixed_prec is not None:
if (self.mixed_prec['compute_prec'] != 'float16' or self.mixed_prec['output_prec'] != 'float32'):
raise RuntimeError(
"Unsupported mixed precision option [output_prec, compute_prec]: [%s, %s], "
" Supported: [float32, float16], Please set mixed precision option correctly!"
% (self.mixed_prec['output_prec'], self.mixed_prec['compute_prec']))
# self.sys_probs = tr_data['sys_probs']
# self.auto_prob_style = tr_data['auto_prob']
self.useBN = False
@@ -289,6 +298,10 @@ def build (self,
tf.constant("compressed_model", name = 'model_type', dtype = tf.string)
else:
tf.constant("original_model", name = 'model_type', dtype = tf.string)

if self.mixed_prec is not None:
self.descrpt.enable_mixed_precision(self.mixed_prec)
self.fitting.enable_mixed_precision(self.mixed_prec)

self._build_lr()
self._build_network(data)
@@ -332,6 +345,8 @@ def _build_network(self, data):
self.place_holders,
suffix = "test")

if self.mixed_prec is not None:
self.l2_l = tf.cast(self.l2_l, get_precision(self.mixed_prec['output_prec']))
log.info("built network")

def _build_training(self):
optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
else:
optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate)
if self.mixed_prec is not None:
_TF_VERSION = Version(TF_VERSION)
# check the TF version; mixed precision is not supported when TF < 1.12
if _TF_VERSION < Version('1.12.0'):
raise RuntimeError("TensorFlow version %s is not compatible with the mixed precision setting. Please consider upgrading your TF version!" % TF_VERSION)
elif _TF_VERSION < Version('2.4.0'):
optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
else:
optimizer = tf.mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)
apply_op = optimizer.minimize(loss=self.l2_l,
global_step=self.global_step,
var_list=trainable_variables,
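Extracted as a standalone sketch, the version gate added to `_build_training` reads as follows. The API names are copied from the diff above, not independently verified against every TF release, and the function name and `tf` parameter are illustrative.

```python
from packaging.version import Version

def wrap_optimizer(optimizer, tf, tf_version):
    """Illustrative: wrap an optimizer with the mixed-precision graph rewrite."""
    if Version(tf_version) < Version("1.12.0"):
        # mixed precision is not supported on very old TF
        raise RuntimeError("TensorFlow version %s is not compatible with the "
                           "mixed precision setting!" % tf_version)
    if Version(tf_version) < Version("2.4.0"):
        return tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)
    return tf.mixed_precision.enable_mixed_precision_graph_rewrite(optimizer)
```

The graph rewrite takes care of inserting loss scaling around the wrapped optimizer, which is why the trainer only has to cast `l2_l` to the output precision before calling `minimize`.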