Feat: Add DOSnet training in PT #3486

Merged 31 commits on Mar 25, 2024

Changes from 6 commits

Commits (31)
f1a3a0d
feat: add dos training
anyangml Mar 18, 2024
d4a3965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
3e1f3c6
fix: precommit
anyangml Mar 18, 2024
c4769ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
429f03f
feat: add dos stat
anyangml Mar 19, 2024
91d2a8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
04c7477
fix: training test
anyangml Mar 19, 2024
4d95548
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
d71a117
Merge branch 'devel' into feat/dos-train
anyangml Mar 19, 2024
bc54b68
fix: precommit
anyangml Mar 19, 2024
850ea1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
004cdb2
fix: UTs
anyangml Mar 19, 2024
a116235
fix: UTs
anyangml Mar 19, 2024
dbd2d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
2ff0ae6
fix: stat
anyangml Mar 20, 2024
b7b69fd
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
366f6b4
fix: stat
anyangml Mar 20, 2024
915141c
fix: dp test
anyangml Mar 20, 2024
ed65e19
fix: test examples
anyangml Mar 20, 2024
a73d392
fix UTs
anyangml Mar 20, 2024
bf8fac2
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
3b5be19
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
1630800
fix: add to test_examples
Mar 20, 2024
be076af
fix: update loss
Mar 20, 2024
36d1674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
c156b02
fix: add numb_dos to jit model
anyangml Mar 20, 2024
148c196
chore: refactor
anyangml Mar 20, 2024
3ad66fb
fix: UTs
anyangml Mar 20, 2024
b751921
Merge branch 'devel' into feat/dos-train
anyangml Mar 22, 2024
2e14755
fix: loss
anyangml Mar 24, 2024
3c5e6f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -2,6 +2,9 @@
from .denoise import (
DenoiseLoss,
)
from .dos import (
DOSLoss,
)
from .ener import (
EnergyStdLoss,
)
@@ -21,4 +24,5 @@
"EnergySpinLoss",
"TensorLoss",
"TaskLoss",
"DOSLoss",
]
226 changes: 226 additions & 0 deletions deepmd/pt/loss/dos.py
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class DOSLoss(TaskLoss):
def __init__(
self,
starter_learning_rate: float,
numb_dos: int,
start_pref_dos: float = 1.00,
limit_pref_dos: float = 1.00,
start_pref_cdf: float = 1000,
limit_pref_cdf: float = 1.00,
start_pref_ados: float = 0.0,
limit_pref_ados: float = 0.0,
start_pref_acdf: float = 0.0,
limit_pref_acdf: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.numb_dos = numb_dos
self.inference = inference

self.start_pref_dos = start_pref_dos
self.limit_pref_dos = limit_pref_dos
self.start_pref_cdf = start_pref_cdf
self.limit_pref_cdf = limit_pref_cdf

self.start_pref_ados = start_pref_ados
self.limit_pref_ados = limit_pref_ados
self.start_pref_acdf = start_pref_acdf
self.limit_pref_acdf = limit_pref_acdf

assert (
self.start_pref_dos >= 0.0
and self.limit_pref_dos >= 0.0
and self.start_pref_cdf >= 0.0
and self.limit_pref_cdf >= 0.0
and self.start_pref_ados >= 0.0
and self.limit_pref_ados >= 0.0
and self.start_pref_acdf >= 0.0
and self.limit_pref_acdf >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"

self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
self.has_ados = (start_pref_ados != 0.0 and limit_pref_ados != 0.0) or inference
self.has_acdf = (start_pref_acdf != 0.0 and limit_pref_acdf != 0.0) or inference

assert (
self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
), "Can not assign zero weight to all of the DOS loss prefactors"

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
coef = learning_rate / self.starter_learning_rate
pref_dos = (
self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
)
pref_cdf = (
self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
)
pref_ados = (
self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
)
pref_acdf = (
self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_ados and "atom_dos" in model_pred and "atom_dos" in label:
local_tensor_pred_dos = model_pred["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
local_tensor_label_dos = label["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
diff = (local_tensor_pred_dos - local_tensor_label_dos).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_dos_loss"] = l2_local_loss_dos.detach()
loss += pref_ados * l2_local_loss_dos
rmse_local_dos = l2_local_loss_dos.sqrt()
more_loss["rmse_local_dos"] = rmse_local_dos.detach()
if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
local_tensor_pred_cdf = torch.cumsum(
model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
local_tensor_label_cdf = torch.cumsum(
label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_cdf_loss"] = l2_local_loss_cdf.detach()
loss += pref_acdf * l2_local_loss_cdf
rmse_local_cdf = l2_local_loss_cdf.sqrt()
more_loss["rmse_local_cdf"] = rmse_local_cdf.detach()
if self.has_dos and "dos" in model_pred and "dos" in label:
global_tensor_pred_dos = model_pred["dos"].reshape([-1, self.numb_dos])
global_tensor_label_dos = label["dos"].reshape([-1, self.numb_dos])
diff = global_tensor_pred_dos - global_tensor_label_dos
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_dos = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_dos_loss"] = l2_global_loss_dos.detach()
loss += pref_dos * l2_global_loss_dos
rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
more_loss["rmse_global_dos"] = rmse_global_dos.detach()
if self.has_cdf and "dos" in model_pred and "dos" in label:
global_tensor_pred_cdf = torch.cumsum(
model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
)
global_tensor_label_cdf = torch.cumsum(
label["dos"].reshape([-1, self.numb_dos]), dim=-1
)
diff = global_tensor_pred_cdf - global_tensor_label_cdf
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_cdf = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_cdf_loss"] = l2_global_loss_cdf.detach()
loss += pref_cdf * l2_global_loss_cdf
rmse_global_cdf = l2_global_loss_cdf.sqrt() / atom_num
more_loss["rmse_global_cdf"] = rmse_global_cdf.detach()
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_ados or self.has_acdf:
label_requirement.append(
DataRequirementItem(
"atom_dos",
ndof=self.numb_dos,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_dos or self.has_cdf:
label_requirement.append(
DataRequirementItem(
"dos",
ndof=self.numb_dos,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
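For reference, the loss weights above are not fixed: `forward` rescales each prefactor with the current learning rate, starting from `start_pref_*` and approaching `limit_pref_*` as the learning rate decays (and the CDF terms are simply the cumulative sum of the DOS along the energy grid). Below is a minimal standalone sketch of that schedule; the learning-rate and prefactor values are invented for illustration and are not part of the PR.

# Standalone sketch of the prefactor schedule used in DOSLoss.forward.
# start_lr and the example values below are hypothetical.
start_lr = 1e-3


def scheduled_pref(start_pref: float, limit_pref: float, current_lr: float) -> float:
    # coef = current_lr / start_lr, so the weight equals start_pref at the
    # initial learning rate and tends to limit_pref as the LR decays to zero.
    coef = current_lr / start_lr
    return limit_pref + (start_pref - limit_pref) * coef


print(scheduled_pref(1000.0, 1.0, 1e-3))  # 1000.0 at the first step
print(scheduled_pref(1000.0, 1.0, 1e-5))  # ~10.99 late in training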
65 changes: 65 additions & 0 deletions deepmd/pt/model/task/dos.py
@@ -2,6 +2,7 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
@@ -28,6 +29,13 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -96,6 +104,63 @@
]
)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
for sys in range(len(sampled)):
nframs = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
sys_atom_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

sys_atom_dos = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(sys_atom_dos)
self.bias_dos = torch.tensor(sys_atom_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
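When a system provides only frame-level `dos` labels, `compute_output_stats` falls back to a least-squares fit: the per-frame type counts act as the design matrix and the per-type DOS bias is the unknown, which is what `compute_stats_from_redu` solves internally. A rough NumPy-only sketch of that idea follows; every shape and value in it is invented for illustration.

# NumPy-only sketch of the least-squares bias used as a fallback in
# compute_output_stats. All shapes and values here are hypothetical.
import numpy as np

nframes, ntypes, numb_dos = 8, 2, 5
rng = np.random.default_rng(0)

type_count = rng.integers(1, 10, size=(nframes, ntypes)).astype(float)
true_bias = rng.normal(size=(ntypes, numb_dos))
frame_dos = type_count @ true_bias  # frame-level ("reduced") DOS labels

# Solve type_count @ bias ~= frame_dos for the per-type DOS bias.
bias, *_ = np.linalg.lstsq(type_count, frame_dos, rcond=None)
print(np.allclose(bias, true_bias))  # True for this noiseless toy data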
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
@@ -25,6 +25,7 @@
)
from deepmd.pt.loss import (
DenoiseLoss,
DOSLoss,
EnergySpinLoss,
EnergyStdLoss,
TensorLoss,
@@ -276,7 +277,8 @@
return EnergyStdLoss(**loss_params)
elif loss_type == "dos":
loss_params["starter_learning_rate"] = start_lr
raise NotImplementedError()
loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size
return DOSLoss(**loss_params)
elif loss_type == "ener_spin":
loss_params["starter_learning_rate"] = start_lr
return EnergySpinLoss(**loss_params)
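To tie the pieces together, this is roughly what the `dos` branch of `get_loss` above ends up doing once the PR is merged. The prefactor values and the `numb_dos` number below are hypothetical, and the snippet assumes a deepmd-kit installation that already includes `DOSLoss`.

# Hedged sketch of the "dos" branch above; all values are hypothetical.
from deepmd.pt.loss import DOSLoss

start_lr = 1e-3  # would come from the learning-rate section of the input
loss_params = {
    "start_pref_dos": 1.0,
    "limit_pref_dos": 1.0,
    "start_pref_cdf": 1000.0,
    "limit_pref_cdf": 1.0,
}
loss_params["starter_learning_rate"] = start_lr
# In training.py this is read from _model.model_output_def()["dos"].output_size.
loss_params["numb_dos"] = 250
loss = DOSLoss(**loss_params)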