Feat: Add DOSnet training in PT #3486

Merged · Mar 25, 2024 · 31 commits

Changes from 21 commits

Commits
f1a3a0d
feat: add dos training
anyangml Mar 18, 2024
d4a3965
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
3e1f3c6
fix: precommit
anyangml Mar 18, 2024
c4769ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
429f03f
feat: add dos stat
anyangml Mar 19, 2024
91d2a8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
04c7477
fix: training test
anyangml Mar 19, 2024
4d95548
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
d71a117
Merge branch 'devel' into feat/dos-train
anyangml Mar 19, 2024
bc54b68
fix: precommit
anyangml Mar 19, 2024
850ea1d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
004cdb2
fix: UTs
anyangml Mar 19, 2024
a116235
fix: UTs
anyangml Mar 19, 2024
dbd2d29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
2ff0ae6
fix: stat
anyangml Mar 20, 2024
b7b69fd
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
366f6b4
fix: stat
anyangml Mar 20, 2024
915141c
fix: dp test
anyangml Mar 20, 2024
ed65e19
fix: test examples
anyangml Mar 20, 2024
a73d392
fix UTs
anyangml Mar 20, 2024
bf8fac2
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
3b5be19
Merge branch 'devel' into feat/dos-train
anyangml Mar 20, 2024
1630800
fix: add to test_examples
Mar 20, 2024
be076af
fix: update loss
Mar 20, 2024
36d1674
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 20, 2024
c156b02
fix: add numb_dos to jit model
anyangml Mar 20, 2024
148c196
chore: refactor
anyangml Mar 20, 2024
3ad66fb
fix: UTs
anyangml Mar 20, 2024
b751921
Merge branch 'devel' into feat/dos-train
anyangml Mar 22, 2024
2e14755
fix: loss
anyangml Mar 24, 2024
3c5e6f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 24, 2024
2 changes: 1 addition & 1 deletion deepmd/pt/infer/deep_eval.py
@@ -178,7 +178,7 @@

def get_numb_dos(self) -> int:
"""Get the number of DOS."""
return 0
return self.dp.model["Default"].get_fitting_net().dim_out


def get_has_efield(self):
"""Check if the model has efield."""
4 changes: 4 additions & 0 deletions deepmd/pt/loss/__init__.py
@@ -2,6 +2,9 @@
from .denoise import (
DenoiseLoss,
)
from .dos import (
DOSLoss,
)
from .ener import (
EnergyStdLoss,
)
@@ -21,4 +24,5 @@
"EnergySpinLoss",
"TensorLoss",
"TaskLoss",
"DOSLoss",
]
226 changes: 226 additions & 0 deletions deepmd/pt/loss/dos.py
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
)

import torch

from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class DOSLoss(TaskLoss):
def __init__(
self,
starter_learning_rate: float,
numb_dos: int,
start_pref_dos: float = 1.00,
limit_pref_dos: float = 1.00,
start_pref_cdf: float = 1000,
limit_pref_cdf: float = 1.00,
start_pref_ados: float = 0.0,
limit_pref_ados: float = 0.0,
start_pref_acdf: float = 0.0,
limit_pref_acdf: float = 0.0,
inference=False,
**kwargs,
):
r"""Construct a loss for local and global tensors.

Parameters
----------
tensor_name : str
The name of the tensor in the model predictions to compute the loss.
tensor_size : int
The size (dimension) of the tensor.
label_name : str
The name of the tensor in the labels to compute the loss.
pref_atomic : float
The prefactor of the weight of atomic loss. It should be larger than or equal to 0.
pref : float
The prefactor of the weight of global loss. It should be larger than or equal to 0.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.starter_learning_rate = starter_learning_rate
self.numb_dos = numb_dos
self.inference = inference

self.start_pref_dos = start_pref_dos
self.limit_pref_dos = limit_pref_dos
self.start_pref_cdf = start_pref_cdf
self.limit_pref_cdf = limit_pref_cdf

self.start_pref_ados = start_pref_ados
self.limit_pref_ados = limit_pref_ados
self.start_pref_acdf = start_pref_acdf
self.limit_pref_acdf = limit_pref_acdf

assert (
self.start_pref_dos >= 0.0
and self.limit_pref_dos >= 0.0
and self.start_pref_cdf >= 0.0
and self.limit_pref_cdf >= 0.0
and self.start_pref_ados >= 0.0
and self.limit_pref_ados >= 0.0
and self.start_pref_acdf >= 0.0
and self.limit_pref_acdf >= 0.0
), "Can not assign negative weight to `pref` and `pref_atomic`"

self.has_dos = (start_pref_dos != 0.0 and limit_pref_dos != 0.0) or inference
self.has_cdf = (start_pref_cdf != 0.0 and limit_pref_cdf != 0.0) or inference
self.has_ados = (start_pref_ados != 0.0 and limit_pref_ados != 0.0) or inference
self.has_acdf = (start_pref_acdf != 0.0 and limit_pref_acdf != 0.0) or inference

assert (
self.has_dos or self.has_cdf or self.has_ados or self.has_acdf
), "Can not assign zero weight to all of the loss prefactors"

def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
"""Return loss on local and global tensors.

Parameters
----------
model_pred : dict[str, torch.Tensor]
Model predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.

Returns
-------
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
coef = learning_rate / self.starter_learning_rate
pref_dos = (
self.limit_pref_dos + (self.start_pref_dos - self.limit_pref_dos) * coef
)
pref_cdf = (
self.limit_pref_cdf + (self.start_pref_cdf - self.limit_pref_cdf) * coef
)
pref_ados = (
self.limit_pref_ados + (self.start_pref_ados - self.limit_pref_ados) * coef
)
pref_acdf = (
self.limit_pref_acdf + (self.start_pref_acdf - self.limit_pref_acdf) * coef
)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_ados and "atom_dos" in model_pred and "atom_dos" in label:
local_tensor_pred_dos = model_pred["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
local_tensor_label_dos = label["atom_dos"].reshape(
[-1, natoms, self.numb_dos]
)
diff = (local_tensor_pred_dos - local_tensor_label_dos).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_dos_loss"] = l2_local_loss_dos.detach()
loss += pref_ados * l2_local_loss_dos
rmse_local_dos = l2_local_loss_dos.sqrt()
more_loss["rmse_local_dos"] = rmse_local_dos.detach()
if self.has_acdf and "atom_dos" in model_pred and "atom_dos" in label:
local_tensor_pred_cdf = torch.cumsum(
model_pred["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
local_tensor_label_cdf = torch.cumsum(
label["atom_dos"].reshape([-1, natoms, self.numb_dos]), dim=-1
)
diff = (local_tensor_pred_cdf - local_tensor_label_cdf).reshape(
[-1, self.numb_dos]
)
if "mask" in model_pred:
diff = diff[model_pred["mask"].reshape([-1]).bool()]
l2_local_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_local_cdf_loss"] = l2_local_loss_cdf.detach()
loss += pref_acdf * l2_local_loss_cdf
rmse_local_cdf = l2_local_loss_cdf.sqrt()
more_loss["rmse_local_cdf"] = rmse_local_cdf.detach()
if self.has_dos and "dos" in model_pred and "dos" in label:
global_tensor_pred_dos = model_pred["dos"].reshape([-1, self.numb_dos])
global_tensor_label_dos = label["dos"].reshape([-1, self.numb_dos])
diff = global_tensor_pred_dos - global_tensor_label_dos
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_dos = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_dos = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_dos_loss"] = l2_global_loss_dos.detach()
loss += pref_dos * l2_global_loss_dos
rmse_global_dos = l2_global_loss_dos.sqrt() / atom_num
more_loss["rmse_global_dos"] = rmse_global_dos.detach()
if self.has_cdf and "dos" in model_pred and "dos" in label:
global_tensor_pred_cdf = torch.cumsum(
model_pred["dos"].reshape([-1, self.numb_dos]), dim=-1
)
global_tensor_label_cdf = torch.cumsum(
label["dos"].reshape([-1, self.numb_dos]), dim=-1
)
diff = global_tensor_pred_cdf - global_tensor_label_cdf
if "mask" in model_pred:
atom_num = model_pred["mask"].sum(-1, keepdim=True)
l2_global_loss_cdf = torch.mean(
torch.sum(torch.square(diff) * atom_num, dim=0) / atom_num.sum()
)
atom_num = torch.mean(atom_num.float())
else:
atom_num = natoms
l2_global_loss_cdf = torch.mean(torch.square(diff))
if not self.inference:
more_loss["l2_global_cdf_loss"] = l2_global_loss_cdf.detach()
loss += pref_cdf * l2_global_loss_cdf
rmse_global_cdf = l2_global_loss_cdf.sqrt() / atom_num
more_loss["rmse_global_cdf"] = rmse_global_cdf.detach()
return loss, more_loss

@property
def label_requirement(self) -> List[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
label_requirement = []
if self.has_ados or self.has_acdf:
label_requirement.append(
DataRequirementItem(
"atom_dos",
ndof=self.numb_dos,
atomic=True,
must=False,
high_prec=False,
)
)
if self.has_dos or self.has_cdf:
label_requirement.append(
DataRequirementItem(
"dos",
ndof=self.numb_dos,
atomic=False,
must=False,
high_prec=False,
)
)
return label_requirement
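
For reference, each prefactor in the loss above is interpolated between its start and limit value by the ratio of the current learning rate to the starting learning rate, and the CDF terms apply the same mean-squared error to the cumulative sum of the DOS along the grid axis. A minimal standalone sketch of that logic, using toy tensors and illustrative names that are not part of this diff:

import torch

def interpolated_pref(start: float, limit: float, lr: float, start_lr: float) -> float:
    # Same schedule as in DOSLoss.forward: pref = limit + (start - limit) * lr / start_lr
    coef = lr / start_lr
    return limit + (start - limit) * coef

# Toy data: 2 frames, DOS discretized on 5 grid points.
numb_dos = 5
pred = torch.rand(2, numb_dos)
label = torch.rand(2, numb_dos)

# Global DOS term: plain mean-squared error on the DOS itself.
l2_dos = torch.mean(torch.square(pred - label))

# Global CDF term: mean-squared error on the cumulative sums along the grid axis.
l2_cdf = torch.mean(torch.square(torch.cumsum(pred, dim=-1) - torch.cumsum(label, dim=-1)))

pref_dos = interpolated_pref(1.0, 1.0, lr=5e-4, start_lr=1e-3)
pref_cdf = interpolated_pref(1000.0, 1.0, lr=5e-4, start_lr=1e-3)
loss = pref_dos * l2_dos + pref_cdf * l2_cdf
print(float(loss))
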
66 changes: 66 additions & 0 deletions deepmd/pt/model/task/dos.py
@@ -2,11 +2,13 @@
import copy
import logging
from typing import (
Callable,
List,
Optional,
Union,
)

import numpy as np
import torch

from deepmd.dpmodel import (
@@ -28,6 +30,13 @@
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_stats_from_atomic,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)
@@ -96,6 +105,63 @@
]
)

def compute_output_stats(
self,
merged: Union[Callable[[], List[dict]], List[dict]],
stat_file_path: Optional[DPPath] = None,
) -> None:
"""
Compute the output statistics (e.g. dos bias) for the fitting net from packed data.

Parameters
----------
merged : Union[Callable[[], List[dict]], List[dict]]
- List[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], List[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
stat_file_path : Optional[DPPath]
The path to the stat file.

"""
if stat_file_path is not None:
stat_file_path = stat_file_path / "bias_dos"
if stat_file_path is not None and stat_file_path.is_file():
bias_dos = stat_file_path.load_numpy()
else:
if callable(merged):
# only get data for once
sampled = merged()
else:
sampled = merged
for sys in range(len(sampled)):
nframes = sampled[sys]["atype"].shape[0]

if "atom_dos" in sampled[sys]:
bias_dos = compute_stats_from_atomic(
sampled[sys]["atom_dos"].numpy(force=True),
sampled[sys]["atype"].numpy(force=True),
)[0]
else:
sys_type_count = np.zeros(
(nframes, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["atype"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
force=True
)
sys_bias_redu = sampled[sys]["dos"].numpy(force=True)

bias_dos = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
if stat_file_path is not None:
stat_file_path.save_numpy(bias_dos)
self.bias_dos = torch.tensor(bias_dos, device=env.DEVICE)

@classmethod
def deserialize(cls, data: dict) -> "DOSFittingNet":
data = copy.deepcopy(data)
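
A note on the fallback branch in compute_output_stats above: when only the frame-summed `dos` label is available, the per-type bias is recovered from the per-frame type counts by a least-squares fit, which is essentially the reduction performed by compute_stats_from_redu. A small self-contained NumPy sketch of that idea, with toy shapes and illustrative names only:

import numpy as np

rng = np.random.default_rng(0)
nframes, ntypes, numb_dos = 8, 2, 5

# Per-frame count of atoms of each type, and a "true" per-type DOS bias.
type_count = rng.integers(1, 10, size=(nframes, ntypes)).astype(float)
true_bias = rng.random((ntypes, numb_dos))

# Frame-reduced DOS label: the sum of the per-type contributions.
dos_redu = type_count @ true_bias

# Least-squares recovery of the per-type bias from the reduced label,
# analogous to compute_stats_from_redu(dos_redu, type_count, rcond=None).
bias, *_ = np.linalg.lstsq(type_count, dos_redu, rcond=None)
print(np.allclose(bias, true_bias))  # True up to numerical noise
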
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
@@ -25,6 +25,7 @@
)
from deepmd.pt.loss import (
DenoiseLoss,
DOSLoss,
EnergySpinLoss,
EnergyStdLoss,
TensorLoss,
@@ -276,7 +277,8 @@ def get_loss(loss_params, start_lr, _ntypes, _model):
return EnergyStdLoss(**loss_params)
elif loss_type == "dos":
loss_params["starter_learning_rate"] = start_lr
raise NotImplementedError()
loss_params["numb_dos"] = _model.model_output_def()["dos"].output_size
return DOSLoss(**loss_params)
elif loss_type == "ener_spin":
loss_params["starter_learning_rate"] = start_lr
return EnergySpinLoss(**loss_params)
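
For context on the wiring above: get_loss forwards the user-supplied loss parameters to DOSLoss and fills in starter_learning_rate and numb_dos itself. A minimal sketch of the equivalent direct construction, with placeholder values rather than recommended settings:

from deepmd.pt.loss import DOSLoss

# numb_dos and starter_learning_rate are normally filled in by get_loss()
# from the model output definition and the learning-rate schedule.
loss = DOSLoss(
    starter_learning_rate=1e-3,
    numb_dos=250,
    start_pref_dos=1.0,
    limit_pref_dos=1.0,
    start_pref_cdf=1000.0,
    limit_pref_cdf=1.0,
)
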
8 changes: 6 additions & 2 deletions deepmd/utils/out_stat.py
@@ -113,6 +113,10 @@ def compute_stats_from_atomic(
output_std = np.zeros((ntypes, ndim))
for type_i in range(ntypes):
mask = atype == type_i
output_bias[type_i] = output[mask].mean(axis=0)
output_std[type_i] = output[mask].std(axis=0)
output_bias[type_i] = (
output[mask].mean(axis=0) if output[mask].size > 0 else np.nan
)
output_std[type_i] = (
output[mask].std(axis=0) if output[mask].size > 0 else np.nan
)
return output_bias, output_std
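
The change above handles element types that never occur in the sampled data: instead of taking the mean and standard deviation of an empty slice, the statistics for such a type are set to NaN explicitly. A small self-contained sketch of the same masking logic on toy arrays:

import numpy as np

# Toy atomic output: 4 atoms with 3 output channels, but only types 0 and 2 present.
output = np.arange(12, dtype=float).reshape(4, 3)
atype = np.array([0, 0, 2, 2])
ntypes = 3

bias = np.full((ntypes, 3), np.nan)
for type_i in range(ntypes):
    mask = atype == type_i
    if output[mask].size > 0:  # same guard as the patched compute_stats_from_atomic
        bias[type_i] = output[mask].mean(axis=0)
print(bias)  # the row for type 1 stays NaN because no such atom was sampled
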
@@ -0,0 +1 @@
Si
Binary file not shown.
Binary file not shown.