Skip to content

Commit

Permalink
Feat: support virtual atom (#3469)
Browse files Browse the repository at this point in the history
- support virtual atoms. the atoms with type -1 will be treated as
virtual.
- the atomic contribution of virtual atoms is zero
- provide ret_dict["mask"] to indicate which atomic contribution is real
(==1) or virtual (==0)

---------

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang committed Mar 16, 2024
1 parent 39cb4d1 commit 4b3a77b
Show file tree
Hide file tree
Showing 13 changed files with 431 additions and 91 deletions.
82 changes: 57 additions & 25 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,22 +56,19 @@ def reinit_pair_exclude(

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
return old_def
else:
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)

def forward_common_atomic(
self,
Expand All @@ -82,31 +79,66 @@ def forward_common_atomic(
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
) -> Dict[str, np.ndarray]:
"""Common interface for atomic inference.
This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.
Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dict["mask"] of shape nf x nloc will be provided.
ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]
if self.pair_excl is not None:
pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype)
# exclude neighbors in the nlist
nlist = np.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)
ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
np.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(np.int32)
if self.atom_excl is not None:
atom_mask = self.atom_excl.build_type_exclude_mask(atype)
for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask

return ret_dict

Expand Down
22 changes: 22 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,28 @@ def serialize(self) -> dict:
def deserialize(cls, data: dict):
pass

def make_atom_mask(
self,
atype: t_tensor,
) -> t_tensor:
"""The atoms with type < 0 are treated as virutal atoms,
which serves as place-holders for multi-frame calculations
with different number of atoms in different frames.
Parameters
----------
atype
Atom types. >= 0 for real atoms <0 for virtual atoms.
Returns
-------
mask
True for real atoms and False for virutal atoms.
"""
# supposed to be supported by all backends
return atype >= 0

def do_grad_r(
self,
var_name: Optional[str] = None,
Expand Down
22 changes: 16 additions & 6 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

## translated from torch implemantation by chatgpt
def build_neighbor_list(
coord1: np.ndarray,
coord: np.ndarray,
atype: np.ndarray,
nloc: int,
rcut: float,
Expand All @@ -26,10 +26,11 @@ def build_neighbor_list(
Parameters
----------
coord1 : np.ndarray
coord : np.ndarray
exptended coordinates of shape [batch_size, nall x 3]
atype : np.ndarray
extended atomic types of shape [batch_size, nall]
type < 0 the atom is treat as virtual atoms.
nloc : int
number of local atoms.
rcut : float
Expand All @@ -54,11 +55,20 @@ def build_neighbor_list(
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = coord1.shape[0]
coord1 = coord1.reshape(batch_size, -1)
nall = coord1.shape[1] // 3
batch_size = coord.shape[0]
coord = coord.reshape(batch_size, -1)
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = np.max(coord) + 2.0 * rcut
# nf x nall
is_vir = atype < 0
coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape(
-1, nall * 3
)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down Expand Up @@ -88,7 +98,7 @@ def build_neighbor_list(
axis=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = np.where((rr > rcut), -1, nlist)
nlist = np.where(np.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
Expand Down
105 changes: 80 additions & 25 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,44 @@ def reinit_pair_exclude(
else:
self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types)

# to make jit happy...
def make_atom_mask(
self,
atype: torch.Tensor,
) -> torch.Tensor:
"""The atoms with type < 0 are treated as virutal atoms,
which serves as place-holders for multi-frame calculations
with different number of atoms in different frames.
Parameters
----------
atype
Atom types. >= 0 for real atoms <0 for virtual atoms.
Returns
-------
mask
True for real atoms and False for virutal atoms.
"""
# supposed to be supported by all backends
return atype >= 0

def atomic_output_def(self) -> FittingOutputDef:
old_def = self.fitting_output_def()
if self.atom_excl is None:
return old_def
else:
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)
old_list = list(old_def.get_data().values())
return FittingOutputDef(
old_list # noqa:RUF005
+ [
OutputVariableDef(
name="mask",
shape=[1],
reduciable=False,
r_differentiable=False,
c_differentiable=False,
)
]
)

def forward_common_atomic(
self,
Expand All @@ -86,6 +106,37 @@ def forward_common_atomic(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
) -> Dict[str, torch.Tensor]:
"""Common interface for atomic inference.
This method accept extended coordinates, extended atom typs, neighbor list,
and predict the atomic contribution of the fit property.
Parameters
----------
extended_coord
extended coodinates, shape: nf x (nall x 3)
extended_atype
extended atom typs, shape: nf x nall
for a type < 0 indicating the atomic is virtual.
nlist
neighbor list, shape: nf x nloc x nsel
mapping
extended to local index mapping, shape: nf x nall
fparam
frame parameters, shape: nf x dim_fparam
aparam
atomic parameter, shape: nf x nloc x dim_aparam
Returns
-------
ret_dict
dict of output atomic properties.
should implement the definition of `fitting_output_def`.
ret_dict["mask"] of shape nf x nloc will be provided.
ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real.
ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual.
"""
_, nloc, _ = nlist.shape
atype = extended_atype[:, :nloc]

Expand All @@ -94,24 +145,28 @@ def forward_common_atomic(
# exclude neighbors in the nlist
nlist = torch.where(pair_mask == 1, nlist, -1)

ext_atom_mask = self.make_atom_mask(extended_atype)
ret_dict = self.forward_atomic(
extended_coord,
extended_atype,
torch.where(ext_atom_mask, extended_atype, 0),
nlist,
mapping=mapping,
fparam=fparam,
aparam=aparam,
)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].to(torch.int32)
if self.atom_excl is not None:
atom_mask = self.atom_excl(atype)
for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).reshape(out_shape)
ret_dict["mask"] = atom_mask
atom_mask *= self.atom_excl(atype)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], -1])
* atom_mask[:, :, None]
).view(out_shape)
ret_dict["mask"] = atom_mask

return ret_dict

Expand Down
24 changes: 18 additions & 6 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def extend_input_and_build_neighbor_list(


def build_neighbor_list(
coord1: torch.Tensor,
coord: torch.Tensor,
atype: torch.Tensor,
nloc: int,
rcut: float,
Expand All @@ -62,10 +62,11 @@ def build_neighbor_list(
Parameters
----------
coord1 : torch.Tensor
coord : torch.Tensor
exptended coordinates of shape [batch_size, nall x 3]
atype : torch.Tensor
extended atomic types of shape [batch_size, nall]
if type < 0 the atom is treat as virtual atoms.
nloc : int
number of local atoms.
rcut : float
Expand All @@ -90,11 +91,20 @@ def build_neighbor_list(
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = coord1.shape[0]
coord1 = coord1.view(batch_size, -1)
nall = coord1.shape[1] // 3
batch_size = coord.shape[0]
coord = coord.view(batch_size, -1)
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
xmax = torch.max(coord) + 2.0 * rcut
# nf x nall
is_vir = atype < 0
coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view(
-1, nall * 3
)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
Expand Down Expand Up @@ -133,7 +143,9 @@ def build_neighbor_list(
dim=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = nlist.masked_fill((rr > rcut), -1)
nlist = torch.where(
torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
Expand Down

0 comments on commit 4b3a77b

Please sign in to comment.