add args decorator for fitting and loss (#2710)
Add `fitting_args_plugin` and `loss_args_plugin` to
`deepmd.utils.argcheck`. With these decorators, new parameters for
fitting and loss can be defined in an external package.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
ChiahsinChu and pre-commit-ci[bot] committed Aug 5, 2023
1 parent 4fa54ec commit 9391e34
Showing 1 changed file with 16 additions and 14 deletions.
deepmd/utils/argcheck.py
@@ -477,6 +477,10 @@ def descrpt_variant_type_args(exclude_hybrid: bool = False) -> Variant:
 
 
 # --- Fitting net configurations: --- #
+fitting_args_plugin = ArgsPlugin()
+
+
+@fitting_args_plugin.register("ener")
 def fitting_ener():
     doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams."
     doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams."
@@ -542,6 +546,7 @@ def fitting_ener():
     ]
 
 
+@fitting_args_plugin.register("dos")
 def fitting_dos():
     doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams."
     doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams."
@@ -584,6 +589,7 @@ def fitting_dos():
     ]
 
 
+@fitting_args_plugin.register("polar")
 def fitting_polar():
     doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
     doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
@@ -635,6 +641,7 @@ def fitting_polar():
 # return fitting_polar()
 
 
+@fitting_args_plugin.register("dipole")
 def fitting_dipole():
     doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
     doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
@@ -681,12 +688,7 @@ def fitting_variant_type_args():
 
     return Variant(
         "type",
-        [
-            Argument("ener", dict, fitting_ener()),
-            Argument("dos", dict, fitting_dos()),
-            Argument("dipole", dict, fitting_dipole()),
-            Argument("polar", dict, fitting_polar()),
-        ],
+        fitting_args_plugin.get_all_argument(),
         optional=True,
         default_tag="ener",
         doc=doc_descrpt_type,
@@ -989,6 +991,10 @@ def limit_pref(item):
     return f"The prefactor of {item} loss at the limit of the training, Should be larger than or equal to 0. i.e. the training step goes to infinity."
 
 
+loss_args_plugin = ArgsPlugin()
+
+
+@loss_args_plugin.register("ener")
 def loss_ener():
     doc_start_pref_e = start_pref("energy", abbr="e")
     doc_limit_pref_e = limit_pref("energy")
@@ -1110,6 +1116,7 @@ def loss_ener():
     ]
 
 
+@loss_args_plugin.register("ener_spin")
 def loss_ener_spin():
     doc_start_pref_e = start_pref("energy")
     doc_limit_pref_e = limit_pref("energy")
@@ -1221,6 +1228,7 @@ def loss_ener_spin():
     ]
 
 
+@loss_args_plugin.register("dos")
 def loss_dos():
     doc_start_pref_dos = start_pref("Density of State (DOS)")
     doc_limit_pref_dos = limit_pref("Density of State (DOS)")
@@ -1295,6 +1303,7 @@ def loss_dos():
 
 
 # YWolfeee: Modified to support tensor type of loss args.
+@loss_args_plugin.register("tensor")
 def loss_tensor():
     # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frams x [9 or 3]."
     # doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
@@ -1319,14 +1328,7 @@ def loss_variant_type_args():
 
     return Variant(
         "type",
-        [
-            Argument("ener", dict, loss_ener()),
-            Argument("dos", dict, loss_dos()),
-            Argument("tensor", dict, loss_tensor()),
-            Argument("ener_spin", dict, loss_ener_spin()),
-            # Argument("polar", dict, loss_tensor()),
-            # Argument("global_polar", dict, loss_tensor("global"))
-        ],
+        loss_args_plugin.get_all_argument(),
        optional=True,
        default_tag="ener",
        doc=doc_loss,
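
As an illustration of the mechanism introduced here, below is a minimal sketch of how an external package might register additional fitting and loss argument sets. The decorator call pattern and the returned list of `dargs.Argument` objects follow the diff above; the type names (`my_fitting`, `my_loss`), the parameters, and their documentation strings are hypothetical and only for illustration.

# Hypothetical external-package code (a sketch, not part of this commit).
from dargs import Argument

from deepmd.utils.argcheck import fitting_args_plugin, loss_args_plugin


@fitting_args_plugin.register("my_fitting")  # hypothetical fitting type name
def fitting_my_fitting():
    doc_scale = "A hypothetical scale factor consumed by the external fitting net."
    return [
        Argument("scale", float, optional=True, default=1.0, doc=doc_scale),
    ]


@loss_args_plugin.register("my_loss")  # hypothetical loss type name
def loss_my_loss():
    doc_pref = "A hypothetical prefactor of the extra loss term."
    return [
        Argument("pref", float, optional=True, default=1.0, doc=doc_pref),
    ]

Once registered, these argument sets would be collected by `fitting_variant_type_args()` and `loss_variant_type_args()` through `get_all_argument()`, so the externally defined types are validated alongside the built-in ones.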
