Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add args decorator for fitting and loss #2710

Merged
merged 2 commits into from
Aug 5, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 16 additions & 14 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,10 @@


# --- Fitting net configurations: --- #
fitting_args_plugin = ArgsPlugin()

Check warning on line 480 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L480

Added line #L480 was not covered by tests


@fitting_args_plugin.register("ener")

Check warning on line 483 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L483

Added line #L483 was not covered by tests
def fitting_ener():
    doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provide the input fparams."
    doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provide the input aparams."
Expand Down Expand Up @@ -542,6 +546,7 @@
]


@fitting_args_plugin.register("dos")

Check warning on line 549 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L549

Added line #L549 was not covered by tests
def fitting_dos():
    doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provide the input fparams."
    doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provide the input aparams."
Expand Down Expand Up @@ -584,6 +589,7 @@
]


@fitting_args_plugin.register("polar")

Check warning on line 592 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L592

Added line #L592 was not covered by tests
def fitting_polar():
doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
Expand Down Expand Up @@ -635,6 +641,7 @@
# return fitting_polar()


@fitting_args_plugin.register("dipole")

Check warning on line 644 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L644

Added line #L644 was not covered by tests
def fitting_dipole():
doc_neuron = "The number of neurons in each hidden layers of the fitting net. When two hidden layers are of the same size, a skip connection is built."
doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
Expand Down Expand Up @@ -681,12 +688,7 @@

return Variant(
"type",
[
Argument("ener", dict, fitting_ener()),
Argument("dos", dict, fitting_dos()),
Argument("dipole", dict, fitting_dipole()),
Argument("polar", dict, fitting_polar()),
],
fitting_args_plugin.get_all_argument(),
optional=True,
default_tag="ener",
doc=doc_descrpt_type,
Expand Down Expand Up @@ -989,6 +991,10 @@
return f"The prefactor of {item} loss at the limit of the training. It should be larger than or equal to 0, i.e. as the training step goes to infinity."


loss_args_plugin = ArgsPlugin()

Check warning on line 994 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L994

Added line #L994 was not covered by tests


@loss_args_plugin.register("ener")

Check warning on line 997 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L997

Added line #L997 was not covered by tests
def loss_ener():
doc_start_pref_e = start_pref("energy", abbr="e")
doc_limit_pref_e = limit_pref("energy")
Expand Down Expand Up @@ -1110,6 +1116,7 @@
]


@loss_args_plugin.register("ener_spin")

Check warning on line 1119 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L1119

Added line #L1119 was not covered by tests
def loss_ener_spin():
doc_start_pref_e = start_pref("energy")
doc_limit_pref_e = limit_pref("energy")
Expand Down Expand Up @@ -1221,6 +1228,7 @@
]


@loss_args_plugin.register("dos")

Check warning on line 1231 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L1231

Added line #L1231 was not covered by tests
def loss_dos():
doc_start_pref_dos = start_pref("Density of State (DOS)")
doc_limit_pref_dos = limit_pref("Density of State (DOS)")
Expand Down Expand Up @@ -1295,6 +1303,7 @@


# YWolfeee: Modified to support tensor type of loss args.
@loss_args_plugin.register("tensor")

Check warning on line 1306 in deepmd/utils/argcheck.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/argcheck.py#L1306

Added line #L1306 was not covered by tests
def loss_tensor():
    # doc_global_weight = "The prefactor of the weight of global loss. It should be larger than or equal to 0. If only `pref` is provided or both are not provided, training will be global mode, i.e. the shape of 'polarizability.npy` or `dipole.npy` should be #frames x [9 or 3]."
# doc_local_weight = "The prefactor of the weight of atomic loss. It should be larger than or equal to 0. If only `pref_atomic` is provided, training will be atomic mode, i.e. the shape of `polarizability.npy` or `dipole.npy` should be #frames x ([9 or 3] x #selected atoms). If both `pref` and `pref_atomic` are provided, training will be combined mode, and atomic label should be provided as well."
Expand All @@ -1319,14 +1328,7 @@

return Variant(
"type",
[
Argument("ener", dict, loss_ener()),
Argument("dos", dict, loss_dos()),
Argument("tensor", dict, loss_tensor()),
Argument("ener_spin", dict, loss_ener_spin()),
# Argument("polar", dict, loss_tensor()),
# Argument("global_polar", dict, loss_tensor("global"))
],
loss_args_plugin.get_all_argument(),
optional=True,
default_tag="ener",
doc=doc_loss,
Expand Down