feat: consistent type embedding (#3617)
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
njzjz committed Mar 31, 2024
1 parent d0d5a94 commit 0be9714
Showing 23 changed files with 558 additions and 29 deletions.
124 changes: 124 additions & 0 deletions deepmd/dpmodel/utils/type_embed.py
@@ -0,0 +1,124 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
    PRECISION_DICT,
    NativeOP,
)
from deepmd.dpmodel.utils.network import (
    EmbeddingNet,
)
from deepmd.utils.version import (
    check_version_compatibility,
)


class TypeEmbedNet(NativeOP):
    r"""Type embedding network.

    Parameters
    ----------
    ntypes : int
        Number of atom types
    neuron : list[int]
        Number of neurons in each hidden layer of the embedding net
    resnet_dt
        Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
    activation_function
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision
        The precision of the embedding net parameters. Supported options are |PRECISION|
    trainable
        If the weights of the embedding net are trainable.
    seed
        Random seed for initializing the network parameters.
    padding
        Concat the zero padding to the output, as the default embedding of the empty type.
    """

    def __init__(
        self,
        *,
        ntypes: int,
        neuron: List[int],
        resnet_dt: bool = False,
        activation_function: str = "tanh",
        precision: str = "default",
        trainable: bool = True,
        seed: Optional[int] = None,
        padding: bool = False,
    ) -> None:
        self.ntypes = ntypes
        self.neuron = neuron
        self.seed = seed
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.activation_function = str(activation_function)
        self.trainable = trainable
        self.padding = padding
        self.embedding_net = EmbeddingNet(
            ntypes,
            self.neuron,
            self.activation_function,
            self.resnet_dt,
            self.precision,
        )

    def call(self) -> np.ndarray:
        """Compute the type embedding network."""
        embed = self.embedding_net(
            np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision])
        )
        if self.padding:
            embed = np.pad(embed, ((0, 1), (0, 0)), mode="constant")
        return embed

    @classmethod
    def deserialize(cls, data: dict):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        TypeEmbedNet
            The deserialized model
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data_cls = data.pop("@class")
        assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

        embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
        type_embedding_net = cls(**data)
        type_embedding_net.embedding_net = embedding_net
        return type_embedding_net

    def serialize(self) -> dict:
        """Serialize the model.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "TypeEmbedNet",
            "@version": 1,
            "ntypes": self.ntypes,
            "neuron": self.neuron,
            "resnet_dt": self.resnet_dt,
            "precision": self.precision,
            "activation_function": self.activation_function,
            "trainable": self.trainable,
            "padding": self.padding,
            "embedding": self.embedding_net.serialize(),
        }
141 changes: 134 additions & 7 deletions deepmd/pt/model/network/network.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
    List,
    Optional,
)

@@ -8,9 +9,15 @@
import torch.nn as nn
import torch.nn.functional as F

from deepmd.pt.model.network.mlp import (
    EmbeddingNet,
)
from deepmd.pt.utils import (
    env,
)
from deepmd.utils.version import (
    check_version_compatibility,
)

try:
    from typing import (
@@ -552,12 +559,12 @@ class TypeEmbedNet(nn.Module):
     def __init__(self, type_nums, embed_dim, bavg=0.0, stddev=1.0):
         """Construct a type embedding net."""
         super().__init__()
-        self.embedding = nn.Embedding(
-            type_nums + 1,
-            embed_dim,
-            padding_idx=type_nums,
-            dtype=env.GLOBAL_PT_FLOAT_PRECISION,
-            device=env.DEVICE,
+        self.embedding = TypeEmbedNetConsistent(
+            ntypes=type_nums,
+            neuron=[embed_dim],
+            padding=True,
+            activation_function="Linear",
+            precision="default",
         )
         # nn.init.normal_(self.embedding.weight[:-1], mean=bavg, std=stddev)
@@ -571,7 +578,7 @@ def forward(self, atype):
         type_embedding:
         """
-        return self.embedding(atype)
+        return self.embedding(atype.device)[atype]

     def share_params(self, base_class, shared_level, resume=False):
         """
@@ -590,6 +597,126 @@ def share_params(self, base_class, shared_level, resume=False):
            raise NotImplementedError


class TypeEmbedNetConsistent(nn.Module):
    r"""Type embedding network that is consistent with other backends.

    Parameters
    ----------
    ntypes : int
        Number of atom types
    neuron : list[int]
        Number of neurons in each hidden layer of the embedding net
    resnet_dt
        Time-step `dt` in the resnet construction: y = x + dt * \phi (Wx + b)
    activation_function
        The activation function in the embedding net. Supported options are |ACTIVATION_FN|
    precision
        The precision of the embedding net parameters. Supported options are |PRECISION|
    trainable
        If the weights of the embedding net are trainable.
    seed
        Random seed for initializing the network parameters.
    padding
        Concat the zero padding to the output, as the default embedding of the empty type.
    """

    def __init__(
        self,
        *,
        ntypes: int,
        neuron: List[int],
        resnet_dt: bool = False,
        activation_function: str = "tanh",
        precision: str = "default",
        trainable: bool = True,
        seed: Optional[int] = None,
        padding: bool = False,
    ):
        """Construct a type embedding net."""
        super().__init__()
        self.ntypes = ntypes
        self.neuron = neuron
        self.seed = seed
        self.resnet_dt = resnet_dt
        self.precision = precision
        self.prec = env.PRECISION_DICT[self.precision]
        self.activation_function = str(activation_function)
        self.trainable = trainable
        self.padding = padding
        # no way to pass seed?
        self.embedding_net = EmbeddingNet(
            ntypes,
            self.neuron,
            self.activation_function,
            self.resnet_dt,
            self.precision,
        )
        for param in self.parameters():
            param.requires_grad = trainable

    def forward(self, device: torch.device):
        """Compute the type embedding network.

        Returns
        -------
        type_embedding: torch.Tensor
            Type embedding network.
        """
        embed = self.embedding_net(
            torch.eye(self.ntypes, dtype=self.prec, device=device)
        )
        if self.padding:
            embed = torch.cat(
                [embed, torch.zeros(1, embed.shape[1], dtype=self.prec, device=device)]
            )
        return embed

    @classmethod
    def deserialize(cls, data: dict):
        """Deserialize the model.

        Parameters
        ----------
        data : dict
            The serialized data

        Returns
        -------
        TypeEmbedNetConsistent
            The deserialized model
        """
        data = data.copy()
        check_version_compatibility(data.pop("@version", 1), 1, 1)
        data_cls = data.pop("@class")
        assert data_cls == "TypeEmbedNet", f"Invalid class {data_cls}"

        embedding_net = EmbeddingNet.deserialize(data.pop("embedding"))
        type_embedding_net = cls(**data)
        type_embedding_net.embedding_net = embedding_net
        return type_embedding_net

    def serialize(self) -> dict:
        """Serialize the model.

        Returns
        -------
        dict
            The serialized data
        """
        return {
            "@class": "TypeEmbedNet",
            "@version": 1,
            "ntypes": self.ntypes,
            "neuron": self.neuron,
            "resnet_dt": self.resnet_dt,
            "precision": self.precision,
            "activation_function": self.activation_function,
            "trainable": self.trainable,
            "padding": self.padding,
            "embedding": self.embedding_net.serialize(),
        }


@torch.jit.script
def gaussian(x, mean, std: float):
    pi = 3.14159
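A similar sketch for the PyTorch counterpart. Note that forward takes a device rather than a batch of atom types; the per-atom lookup happens at the call site by indexing the returned table, exactly as the updated TypeEmbedNet.forward above does. The script is illustrative only:

import torch

from deepmd.pt.model.network.network import TypeEmbedNetConsistent

te = TypeEmbedNetConsistent(ntypes=2, neuron=[8], padding=True)

# The full embedding table: one row per type plus the zero padding row.
table = te(torch.device("cpu"))  # shape (ntypes + 1, neuron[-1]) == (3, 8)

# Per-atom embeddings are obtained by indexing, mirroring the wrapper class:
atype = torch.tensor([0, 1, 1, 0])
per_atom = te(atype.device)[atype]  # shape (4, 8)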
11 changes: 7 additions & 4 deletions deepmd/tf/env.py
@@ -168,11 +168,14 @@ def dlopen_library(module: str, filename: str):
r"share_.+/idt|"
)[:-1]

# subpatterns:
# \1: weight name
# \2: layer index
TYPE_EMBEDDING_PATTERN = str(
r"type_embed_net+/matrix_\d+|"
r"type_embed_net+/bias_\d+|"
r"type_embed_net+/idt_\d+|"
)
r"type_embed_net/(matrix)_(\d+)|"
r"type_embed_net/(bias)_(\d+)|"
r"type_embed_net/(idt)_(\d+)|"
)[:-1]

ATTENTION_LAYER_PATTERN = str(
r"attention_layer_\d+/c_query/matrix|"
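The rewritten TYPE_EMBEDDING_PATTERN drops the stray `+` after `net` and adds capture groups, so callers can recover the weight name and layer index from a matched TF variable name. Since each alternative carries its own pair of groups, the matched pair has to be picked out of groups(); a small illustrative sketch (the example variable name is made up):

import re

pattern = (
    r"type_embed_net/(matrix)_(\d+)|"
    r"type_embed_net/(bias)_(\d+)|"
    r"type_embed_net/(idt)_(\d+)"
)
m = re.search(pattern, "type_embed_net/matrix_1")
# Only one alternative matches; its two groups are the non-None ones.
name, layer = [g for g in m.groups() if g is not None]
assert (name, layer) == ("matrix", "1")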
2 changes: 2 additions & 0 deletions deepmd/tf/model/model.py
@@ -678,6 +678,7 @@ def __init__(
             self.typeebd = type_embedding
         elif type_embedding is not None:
             self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
                 **type_embedding,
                 padding=self.descrpt.explicit_ntypes,
             )
@@ -686,6 +687,7 @@
             default_args_dict = {i.name: i.default for i in default_args}
             default_args_dict["activation_function"] = None
             self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
                 **default_args_dict,
                 padding=True,
             )
4 changes: 3 additions & 1 deletion deepmd/tf/model/multi.py
@@ -146,11 +146,13 @@ def __init__(
             dim_descrpt=self.descrpt.get_dim_out(),
         )

+        self.ntypes = self.descrpt.get_ntypes()
         # type embedding
         if type_embedding is not None and isinstance(type_embedding, TypeEmbedNet):
             self.typeebd = type_embedding
         elif type_embedding is not None:
             self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
                 **type_embedding,
                 padding=self.descrpt.explicit_ntypes,
             )
@@ -159,6 +161,7 @@
             default_args_dict = {i.name: i.default for i in default_args}
             default_args_dict["activation_function"] = None
             self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
                 **default_args_dict,
                 padding=True,
             )
@@ -167,7 +170,6 @@

         # descriptor
         self.rcut = self.descrpt.get_rcut()
-        self.ntypes = self.descrpt.get_ntypes()
         # fitting
         self.fitting_dict = fitting_dict
         self.numb_fparam_dict = {
3 changes: 2 additions & 1 deletion deepmd/tf/model/pairwise_dprc.py
@@ -77,11 +77,13 @@ def __init__(
             compress=compress,
             **kwargs,
         )
+        self.ntypes = len(type_map)
         # type embedding
         if isinstance(type_embedding, TypeEmbedNet):
             self.typeebd = type_embedding
         else:
             self.typeebd = TypeEmbedNet(
+                ntypes=self.ntypes,
                 **type_embedding,
                 # must use se_atten, so it must be True
                 padding=True,
@@ -100,7 +102,6 @@ def __init__(
             compress=compress,
         )
         add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False)
-        self.ntypes = len(type_map)
         self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())

     def build(
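The three TensorFlow-side changes (model.py, multi.py, pairwise_dprc.py) are the same fix: self.ntypes is now assigned before the type-embedding net is built, since TypeEmbedNet requires an explicit ntypes argument after this commit. A trimmed-down sketch of the resulting construction order (the class and its arguments are simplified stand-ins, and the import path is an assumption, not taken from the diff):

from deepmd.tf.utils.type_embed import TypeEmbedNet  # import path assumed


class SomeModel:  # hypothetical stand-in for the real model classes
    def __init__(self, descrpt, type_embedding: dict):
        # ntypes must be known before the type embedding net is created ...
        self.ntypes = descrpt.get_ntypes()
        # ... because TypeEmbedNet now takes it explicitly.
        self.typeebd = TypeEmbedNet(
            ntypes=self.ntypes,
            **type_embedding,
            padding=True,
        )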
