merge compute_output_stat #3310

Merged
merged 12 commits on Feb 28, 2024
56 changes: 0 additions & 56 deletions deepmd/pt/model/descriptor/descriptor.py
@@ -35,62 +35,6 @@
log = logging.getLogger(__name__)


class Descriptor(torch.nn.Module, BaseDescriptor):
"""The descriptor.
Given the atomic coordinates, atomic types and neighbor list,
calculate the descriptor.
"""

__plugins = Plugin()
local_cluster = False

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
Descriptor
the registered descriptor

Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return Descriptor.__plugins.register(key)

@classmethod
def get_data_process_key(cls, config):
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
if cls is not Descriptor:
raise NotImplementedError("get_data_process_key is not implemented!")
descrpt_type = config["type"]
return Descriptor.__plugins.plugins[descrpt_type].get_data_process_key(config)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of descriptor should be set by `type`")
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)
return super().__new__(cls)


class DescriptorBlock(torch.nn.Module, ABC):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
11 changes: 0 additions & 11 deletions deepmd/pt/model/descriptor/dpa1.py
@@ -131,17 +131,6 @@ def dim_emb(self):
def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
return self.se_atten.compute_input_stats(merged, path)

@classmethod
def get_data_process_key(cls, config):
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
descrpt_type = config["type"]
assert descrpt_type in ["dpa1", "se_atten"]
return {"sel": config["sel"], "rcut": config["rcut"]}

def serialize(self) -> dict:
"""Serialize the obj to dict."""
raise NotImplementedError
14 changes: 0 additions & 14 deletions deepmd/pt/model/descriptor/dpa2.py
@@ -306,20 +306,6 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
]
descrpt.compute_input_stats(merged_tmp)

@classmethod
def get_data_process_key(cls, config):
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
descrpt_type = config["type"]
assert descrpt_type in ["dpa2"]
return {
"sel": [config["repinit_nsel"], config["repformer_nsel"]],
"rcut": [config["repinit_rcut"], config["repformer_rcut"]],
}

def serialize(self) -> dict:
"""Serialize the obj to dict."""
raise NotImplementedError
11 changes: 0 additions & 11 deletions deepmd/pt/model/descriptor/se_a.py
@@ -130,17 +130,6 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
return self.sea.compute_input_stats(merged, path)

@classmethod
def get_data_process_key(cls, config):
"""
Get the keys for the data preprocess.
Usually need the information of rcut and sel.
TODO Need to be deprecated when the dataloader has been cleaned up.
"""
descrpt_type = config["type"]
assert descrpt_type in ["se_e2_a"]
return {"sel": config["sel"], "rcut": config["rcut"]}

def forward(
self,
coord_ext: torch.Tensor,
14 changes: 8 additions & 6 deletions deepmd/pt/model/task/ener.py
@@ -28,8 +28,11 @@
from deepmd.pt.utils.env import (
DEFAULT_PRECISION,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)
from deepmd.utils.out_stat import (
compute_output_stat,
compute_stats_from_redu,
)
from deepmd.utils.path import (
DPPath,
@@ -105,8 +108,7 @@
**kwargs,
):
self.dim_out = dim_out
# TODO: atom_ener
self.atom_ener = None
self.atom_ener = atom_ener
super().__init__(
var_name=var_name,
ntypes=ntypes,
@@ -149,16 +151,16 @@
bias_atom_e = stat_file_path.load_numpy()
else:
# shape: (nframes, ndim)
merged_energy = torch.cat(energy).detach().cpu().numpy()
merged_energy = to_numpy_array(torch.cat(energy))
# shape: (nframes, ntypes)
merged_natoms = torch.cat(input_natoms)[:, 2:].detach().cpu().numpy()
merged_natoms = to_numpy_array(torch.cat(input_natoms)[:, 2:])
if self.atom_ener is not None and len(self.atom_ener) > 0:
assigned_atom_ener = np.array(
[ee if ee is not None else np.nan for ee in self.atom_ener]
)
else:
assigned_atom_ener = None
bias_atom_e = compute_output_stat(
bias_atom_e, _ = compute_stats_from_redu(
merged_energy,
merged_natoms,
assigned_bias=assigned_atom_ener,
2 changes: 1 addition & 1 deletion deepmd/pt/utils/env_mat_stat.py
@@ -80,7 +80,7 @@ def iter(
Parameters
----------
data : List[Dict[str, torch.Tensor]]
The environment matrix.
The data.

Yields
------
4 changes: 2 additions & 2 deletions deepmd/tf/fit/dos.py
@@ -44,7 +44,7 @@
one_layer_rand_seed_shift,
)
from deepmd.utils.out_stat import (
compute_output_stat,
compute_stats_from_redu,
)

log = logging.getLogger(__name__)
@@ -228,7 +228,7 @@
sys_tynatom = np.reshape(sys_tynatom, [nsys, -1])
sys_tynatom = sys_tynatom[:, 2:]

dos_shift = compute_output_stat(
dos_shift, _ = compute_stats_from_redu(
sys_dos,
sys_tynatom,
rcond=rcond,
8 changes: 4 additions & 4 deletions deepmd/tf/fit/ener.py
@@ -54,7 +54,7 @@
Spin,
)
from deepmd.utils.out_stat import (
compute_output_stat,
compute_stats_from_redu,
)

if TYPE_CHECKING:
@@ -299,13 +299,13 @@ def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
)
else:
assigned_atom_ener = None
energy_shift = compute_output_stat(
energy_shift, _ = compute_stats_from_redu(
sys_ener.reshape(-1, 1),
sys_tynatom,
assigned_bias=assigned_atom_ener,
rcond=rcond,
).ravel()
return energy_shift
)
return energy_shift.ravel()

def compute_input_stats(self, all_stat: dict, protection: float = 1e-2) -> None:
"""Compute the input statistics.
6 changes: 2 additions & 4 deletions deepmd/tf/fit/polar.py
@@ -148,16 +148,14 @@ def get_out_size(self) -> int:
"""Get the output size. Should be 9."""
return 9

def compute_input_stats(self, all_stat, protection=1e-2):
"""Compute the input statistics.
def compute_output_stats(self, all_stat):
"""Compute the output statistics.

Parameters
----------
all_stat
Dictionary of inputs.
can be prepared by model.make_stat_input
protection
Divided-by-zero protection
"""
if "polarizability" not in all_stat.keys():
self.avgeig = np.zeros([9])
5 changes: 3 additions & 2 deletions deepmd/utils/data_system.py
@@ -23,7 +23,7 @@
DeepmdData,
)
from deepmd.utils.out_stat import (
compute_output_stat,
compute_stats_from_redu,
)

log = logging.getLogger(__name__)
@@ -251,9 +251,10 @@ def compute_energy_shift(self, rcond=None, key="energy"):
sys_tynatom = np.array(self.natoms_vec, dtype=GLOBAL_NP_FLOAT_PRECISION)
sys_tynatom = np.reshape(sys_tynatom, [self.nsystems, -1])
sys_tynatom = sys_tynatom[:, 2:]
energy_shift = compute_output_stat(
energy_shift, _ = compute_stats_from_redu(
sys_ener.reshape(-1, 1),
sys_tynatom,
rcond=rcond,
)
return energy_shift.ravel()

55 changes: 51 additions & 4 deletions deepmd/utils/out_stat.py
@@ -2,21 +2,22 @@
"""Output statistics."""
from typing import (
Optional,
Tuple,
)

import numpy as np


def compute_output_stat(
def compute_stats_from_redu(
output_redu: np.ndarray,
natoms: np.ndarray,
assigned_bias: Optional[np.ndarray] = None,
rcond: Optional[float] = None,
) -> np.ndarray:
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute the output statistics.

Given the reduced output value and the number of atoms for each type of atom,
compute the least-squares solution as the atomic output bias.
compute the least-squares solution as the atomic output bias and std.

Parameters
----------
@@ -34,6 +35,8 @@ def compute_output_stat(
-------
np.ndarray
The computed output bias, shape is [ntypes, ndim].
np.ndarray
The computed output std, shape is [ndim].
"""
output_redu = np.array(output_redu)
natoms = np.array(natoms)
@@ -67,4 +70,48 @@
if assigned_bias is not None:
# add back assigned atom; this might not be required
computed_output_bias[assigned_bias_atom_mask] = assigned_bias_masked
return computed_output_bias
# rest_redu: nframes, ndim
rest_redu = output_redu - np.einsum("ij,jk->ik", natoms, computed_output_bias)
output_std = rest_redu.std(axis=0)
return computed_output_bias, output_std
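
As a quick sanity check (not part of the PR), here is a minimal sketch of how the renamed helper can be called; the toy `natoms` / `true_bias` data below is made up for illustration:

```python
import numpy as np

from deepmd.utils.out_stat import compute_stats_from_redu

# Hypothetical toy data: 4 frames, 2 atom types, 1 output dimension.
rng = np.random.default_rng(0)
natoms = rng.integers(1, 10, size=(4, 2)).astype(np.float64)  # per-frame atom counts per type
true_bias = np.array([[1.0], [-2.0]])                         # per-type bias to recover
output_redu = natoms @ true_bias + rng.normal(scale=1e-3, size=(4, 1))

# Least-squares fit of natoms @ bias ~= output_redu; std is the per-dimension
# standard deviation of the residual output_redu - natoms @ bias.
bias, std = compute_stats_from_redu(output_redu, natoms)
print(bias.shape, std.shape)  # (2, 1) (1,)
```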


def compute_stats_from_atomic(
output: np.ndarray,
atype: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute the output statistics.

Given the output value and the type of atoms,
compute the atomic output bias and std.

Parameters
----------
output
The output value, shape is [nframes, nloc, ndim].
atype
The type of atoms, shape is [nframes, nloc].

Returns
-------
np.ndarray
The computed output bias, shape is [ntypes, ndim].
np.ndarray
The computed output std, shape is [ntypes, ndim].
"""
output = np.array(output)
atype = np.array(atype)
# check shape
assert output.ndim == 3
assert atype.ndim == 2
assert output.shape[:2] == atype.shape
# compute output bias
nframes, nloc, ndim = output.shape
ntypes = atype.max() + 1
output_bias = np.zeros((ntypes, ndim))
output_std = np.zeros((ntypes, ndim))
for type_i in range(ntypes):
mask = atype == type_i
output_bias[type_i] = output[mask].mean(axis=0)
output_std[type_i] = output[mask].std(axis=0)
return output_bias, output_std
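
A similar sketch (again with made-up data) for the new per-atom variant, which simply averages the outputs over atoms of each type:

```python
import numpy as np

from deepmd.utils.out_stat import compute_stats_from_atomic

# Hypothetical toy data: 3 frames, 5 local atoms, 1 output dimension.
rng = np.random.default_rng(0)
atype = rng.integers(0, 2, size=(3, 5))                  # types 0 and 1
output = rng.normal(size=(3, 5, 1)) + atype[..., None]   # type-1 atoms shifted by +1

# bias[t] / std[t] are the mean / std of outputs over atoms of type t,
# so bias is roughly [[0.0], [1.0]] here.
bias, std = compute_stats_from_atomic(output, atype)
print(bias.shape, std.shape)  # (2, 1) (2, 1)
```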