Skip to content

Commit

Permalink
breaking: remove multi-task support in tf (#3763)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Removed the `multi_task` parameter across various descriptor
initialization methods, streamlining the setup process.
- Introduced a new option `--head` for specifying a model branch to
freeze in multi-task mode.
- **Bug Fixes**
- Corrected initialization and training processes by removing outdated
multi-task functionalities.
- **Documentation**
- Updated guides on model freezing and training to reflect the removal
of multi-task functionalities and the shift towards using the PyTorch
backend.
- **Refactor**
- Eliminated redundant code and simplified parameter assignments in
training scripts.
- **Chores**
- Removed unused dictionaries and outdated code across several modules
to clean up the codebase.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
iProzd and coderabbitai[bot] committed May 11, 2024
1 parent 74dce7f commit 063de8a
Show file tree
Hide file tree
Showing 29 changed files with 183 additions and 3,115 deletions.
2 changes: 1 addition & 1 deletion deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def __init__(
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
## seed, uniform_seed, not included.
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
Expand Down
4 changes: 1 addition & 3 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ class DescrptSeA(NativeOP, BaseDescriptor):
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
multi_task
If the model has multi fitting nets to train.
spin
The deepspin object.
Expand Down Expand Up @@ -159,7 +157,7 @@ def __init__(
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
## seed, uniform_seed, not included.
if spin is not None:
raise NotImplementedError("spin is not implemented")

Expand Down
4 changes: 1 addition & 3 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ class DescrptSeR(NativeOP, BaseDescriptor):
The activation function in the embedding net. Supported options are |ACTIVATION_FN|
precision
The precision of the embedding net parameters. Supported options are |PRECISION|
multi_task
If the model has multi fitting nets to train.
spin
The deepspin object.
Expand Down Expand Up @@ -114,7 +112,7 @@ def __init__(
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
## seed, uniform_seed, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
if spin is not None:
Expand Down
6 changes: 0 additions & 6 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,12 +323,6 @@ def main_parser() -> argparse.ArgumentParser:
default=None,
help="(Supported backend: TensorFlow) the name of weight file (.npy), if set, save the model's weight into the file",
)
parser_frz.add_argument(
"--united-model",
action="store_true",
default=False,
help="(Supported backend: TensorFlow) When in multi-task mode, freeze all nodes into one united model",
)
parser_frz.add_argument(
"--head",
default=None,
Expand Down
4 changes: 1 addition & 3 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class DescrptHybrid(Descriptor):
def __init__(
self,
list: List[Union[Descriptor, Dict[str, Any]]],
multi_task: bool = False,
ntypes: Optional[int] = None,
spin: Optional[Spin] = None,
**kwargs,
Expand All @@ -59,13 +58,12 @@ def __init__(
"cannot build descriptor from an empty list of descriptors."
)
formatted_descript_list = []
self.multi_task = multi_task
for ii in descrpt_list:
if isinstance(ii, Descriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
formatted_descript_list.append(
Descriptor(**ii, ntypes=ntypes, spin=spin, multi_task=multi_task)
Descriptor(**ii, ntypes=ntypes, spin=spin)
)
else:
raise NotImplementedError
Expand Down
35 changes: 8 additions & 27 deletions deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ class DescrptSeA(DescrptSe):
The precision of the embedding net parameters. Supported options are |PRECISION|
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
multi_task
If the model has multi fitting nets to train.
env_protection: float
Protection parameter to prevent division by zero errors during environment matrix calculations.
Expand Down Expand Up @@ -181,7 +179,6 @@ def __init__(
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
tebd_input_mode: str = "concat",
env_protection: float = 0.0, # not implement!!
Expand Down Expand Up @@ -304,15 +301,6 @@ def __init__(
self.stat_descrpt *= tf.reshape(mask, tf.shape(self.stat_descrpt))
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.original_sel = None
self.multi_task = multi_task
if multi_task:
self.stat_dict = {
"sumr": [],
"suma": [],
"sumn": [],
"sumr2": [],
"suma2": [],
}

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -392,21 +380,14 @@ def compute_input_stats(
sumn.append(sysn)
sumr2.append(sysr2)
suma2.append(sysa2)
if not self.multi_task:
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)
else:
self.stat_dict["sumr"] += sumr
self.stat_dict["suma"] += suma
self.stat_dict["sumn"] += sumn
self.stat_dict["sumr2"] += sumr2
self.stat_dict["suma2"] += suma2
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)

def merge_input_stats(self, stat_dict):
"""Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd.
Expand Down
2 changes: 0 additions & 2 deletions deepmd/tf/descriptor/se_a_ebd_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
**kwargs,
) -> None:
Expand All @@ -63,7 +62,6 @@ def __init__(
activation_function=activation_function,
precision=precision,
uniform_seed=uniform_seed,
multi_task=multi_task,
spin=spin,
tebd_input_mode="strip",
**kwargs,
Expand Down
28 changes: 8 additions & 20 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,6 @@ class DescrptSeAtten(DescrptSeA):
Whether to mask the diagonal in the attention weights.
ln_eps: float, Optional
The epsilon value for layer normalization.
multi_task: bool
If the model has multi fitting nets to train.
tebd_input_mode: str
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
Expand Down Expand Up @@ -188,7 +186,6 @@ def __init__(
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
multi_task: bool = False,
smooth_type_embedding: bool = False,
tebd_input_mode: str = "concat",
# not implemented
Expand Down Expand Up @@ -246,7 +243,6 @@ def __init__(
activation_function=activation_function,
precision=precision,
uniform_seed=uniform_seed,
multi_task=multi_task,
)
"""
Constructor
Expand Down Expand Up @@ -403,21 +399,14 @@ def compute_input_stats(
sumn.append(sysn)
sumr2.append(sysr2)
suma2.append(sysa2)
if not self.multi_task:
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)
else:
self.stat_dict["sumr"] += sumr
self.stat_dict["suma"] += suma
self.stat_dict["sumn"] += sumn
self.stat_dict["sumr2"] += sumr2
self.stat_dict["suma2"] += suma2
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)

def enable_compression(
self,
Expand Down Expand Up @@ -2117,7 +2106,6 @@ def __init__(
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
multi_task=True,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth_type_embedding=smooth_type_embedding,
Expand Down
4 changes: 0 additions & 4 deletions deepmd/tf/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class DescrptSeAttenV2(DescrptSeAtten):
Whether to dot the relative coordinates on the attention weights as a gated scheme.
attn_mask
Whether to mask the diagonal in the attention weights.
multi_task
If the model has multi fitting nets to train.
"""

def __init__(
Expand All @@ -84,7 +82,6 @@ def __init__(
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
multi_task: bool = False,
**kwargs,
) -> None:
DescrptSeAtten.__init__(
Expand All @@ -108,7 +105,6 @@ def __init__(
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
multi_task=multi_task,
tebd_input_mode="strip",
smooth_type_embedding=True,
**kwargs,
Expand Down
13 changes: 2 additions & 11 deletions deepmd/tf/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
env_protection: float = 0.0, # not implement!!
**kwargs,
Expand Down Expand Up @@ -211,9 +210,6 @@ def __init__(
self.sub_sess = tf.Session(
graph=sub_graph, config=default_tf_session_config
)
self.multi_task = multi_task
if multi_task:
self.stat_dict = {"sumr": [], "sumn": [], "sumr2": []}

def get_rcut(self):
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -282,13 +278,8 @@ def compute_input_stats(
sumr.append(sysr)
sumn.append(sysn)
sumr2.append(sysr2)
if not self.multi_task:
stat_dict = {"sumr": sumr, "sumn": sumn, "sumr2": sumr2}
self.merge_input_stats(stat_dict)
else:
self.stat_dict["sumr"] += sumr
self.stat_dict["sumn"] += sumn
self.stat_dict["sumr2"] += sumr2
stat_dict = {"sumr": sumr, "sumn": sumn, "sumr2": sumr2}
self.merge_input_stats(stat_dict)

def merge_input_stats(self, stat_dict):
"""Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd.
Expand Down
33 changes: 8 additions & 25 deletions deepmd/tf/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def __init__(
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
**kwargs,
) -> None:
"""Constructor."""
Expand Down Expand Up @@ -172,15 +171,6 @@ def __init__(
sel_r=self.sel_r,
)
self.sub_sess = tf.Session(graph=sub_graph, config=default_tf_session_config)
self.multi_task = multi_task
if multi_task:
self.stat_dict = {
"sumr": [],
"suma": [],
"sumn": [],
"sumr2": [],
"suma2": [],
}

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -256,21 +246,14 @@ def compute_input_stats(
sumn.append(sysn)
sumr2.append(sysr2)
suma2.append(sysa2)
if not self.multi_task:
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)
else:
self.stat_dict["sumr"] += sumr
self.stat_dict["suma"] += suma
self.stat_dict["sumn"] += sumn
self.stat_dict["sumr2"] += sumr2
self.stat_dict["suma2"] += suma2
stat_dict = {
"sumr": sumr,
"suma": suma,
"sumn": sumn,
"sumr2": sumr2,
"suma2": suma2,
}
self.merge_input_stats(stat_dict)

def merge_input_stats(self, stat_dict):
"""Merge the statisitcs computed from compute_input_stats to obtain the self.davg and self.dstd.
Expand Down

0 comments on commit 063de8a

Please sign in to comment.