Skip to content

Commit

Permalink
add edgenetwork layer
Browse files Browse the repository at this point in the history
  • Loading branch information
riya-singh28 committed Jun 21, 2023
1 parent 3bd7bda commit 07b2f6c
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 2 deletions.
Binary file not shown.
Binary file not shown.
67 changes: 67 additions & 0 deletions deepchem/models/tests/test_layers.py
Expand Up @@ -1013,3 +1013,70 @@ def test_dtnn_embedding():
result_torch = embedding_layer_torch(torch.tensor([3, 2, 4]))
assert torch.allclose(torch.tensor(results_tf), result_torch)
assert result_torch.shape == (3, 5)


@pytest.mark.torch
def test_edge_network():
"""Test invoking the Torch equivalent of EdgeNetwork."""
# init parameters
n_pair_features = 14
n_hidden = 75 # based on weave featurizer
torch_init = 'xavier_uniform_'

# generate features for testing
mols = ["CCC"]
featurizer = dc.feat.WeaveFeaturizer()
features = featurizer.featurize(mols)
X_b = np.asarray([features[0]])
X_b = dc.data.pad_features(1, X_b)

atom_feat = []
pair_feat = []
atom_to_pair = []
start = 0
for mol in X_b:
n_atoms = mol.get_num_atoms()

# index of pair features
C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
atom_to_pair.append(
np.transpose(np.array([C1.flatten() + start,
C0.flatten() + start])))
start = start + n_atoms

# atom features
atom_feat.append(mol.get_atom_features())

# pair features
pair_feat.append(
np.reshape(mol.get_pair_features(),
(n_atoms * n_atoms, n_pair_features)))

atom_features = np.concatenate(atom_feat, axis=0)
pair_features = np.concatenate(pair_feat, axis=0)
atom_to_pair_array = np.concatenate(atom_to_pair, axis=0)

# tensors for torch layer
torch_pair_features = torch.Tensor(pair_features)
torch_atom_features = torch.Tensor(atom_features)
torch_atom_to_pair = torch.Tensor(atom_to_pair_array)
torch_atom_to_pair = torch.squeeze(torch_atom_to_pair.to(torch.int64),
dim=0)

torch_inputs = [
torch_pair_features, torch_atom_features, torch_atom_to_pair
]

torch_layer = dc.models.torch_models.layers.EdgeNetwork(
n_pair_features, n_hidden, torch_init)

# assigning tensorflow layer weights to torch layer
torch_layer.W = torch.from_numpy(
np.load("deepchem/models/tests/assets/edgenetwork_weights.npy"))

torch_result = torch_layer(torch_inputs)

assert np.allclose(
np.array(torch_result),
np.load("deepchem/models/tests/assets/edgenetwork_result.npy"),
atol=1e-04)
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/__init__.py
Expand Up @@ -16,7 +16,7 @@
from deepchem.models.torch_models.mat import MAT, MATModel
from deepchem.models.torch_models.megnet import MEGNetModel
from deepchem.models.torch_models.normalizing_flows_pytorch import NormalizingFlow
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork
from deepchem.models.torch_models.cnn import CNN
from deepchem.models.torch_models.attention import ScaledDotProductAttention, SelfAttention
from deepchem.models.torch_models.grover import GroverModel, GroverPretrain, GroverFinetune
Expand Down
87 changes: 86 additions & 1 deletion deepchem/models/torch_models/layers.py
Expand Up @@ -16,7 +16,7 @@
pass

from deepchem.utils.typing import OneOrMany, ActivationFn, ArrayLike
from deepchem.utils.pytorch_utils import get_activation
from deepchem.utils.pytorch_utils import get_activation, segment_sum
from torch.nn import init as initializers


Expand Down Expand Up @@ -3056,3 +3056,88 @@ def forward(self, inputs: torch.Tensor):
atom_enbeddings = torch.nn.functional.embedding(atom_number,
self.embedding_list)
return atom_enbeddings


class EdgeNetwork(nn.Module):
"""The EdgeNetwork module is a PyTorch submodule designed for message passing in graph neural networks.
Examples
--------
>>> pair_features = torch.rand((4, 2), dtype=torch.float32)
>>> atom_features = torch.rand((5, 2), dtype=torch.float32)
>>> atom_to_pair = []
>>> n_atoms = 2
>>> start = 0
>>> C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
>>> atom_to_pair.append(np.transpose(np.array([C1.flatten() + start, C0.flatten() + start])))
>>> atom_to_pair = torch.Tensor(atom_to_pair)
>>> atom_to_pair = torch.squeeze(atom_to_pair.to(torch.int64), dim=0)
>>> inputs = [pair_features, atom_features, atom_to_pair]
>>> n_pair_features = 2
>>> n_hidden = 2
>>> init = 'xavier_uniform_'
>>> layer = EdgeNetwork(n_pair_features, n_hidden, init)
>>> result = layer(inputs)
>>> result.shape[1]
2
"""

def __init__(self,
n_pair_features: int = 8,
n_hidden: int = 100,
init: str = 'xavier_uniform_',
**kwargs):
"""Initalises a EdgeNetwork Layer
Parameters
----------
n_pair_features: int, optional
The length of the pair features vector.
n_hidden: int, optional
number of hidden units in the passing phase
init: str, optional
Initialization function to be used in the message passing layer.
"""

super(EdgeNetwork, self).__init__(**kwargs)
self.n_pair_features: int = n_pair_features
self.n_hidden: int = n_hidden
self.init: str = init

init_func: Callable = getattr(initializers, self.init)
self.W: torch.Tensor = init_func(
torch.empty([self.n_pair_features, self.n_hidden * self.n_hidden]))
self.b: torch.Tensor = torch.zeros((self.n_hidden * self.n_hidden,))
self.built: bool = True

def __repr__(self) -> str:
return (
f'{self.__class__.__name__}(n_pair_features:{self.n_pair_features},n_hidden:{self.n_hidden},init:{self.init})'
)

def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
"""
Parameters
----------
inputs: List[torch.Tensor]
The length of atom_to_pair should be same as n_pair_features.
Returns
-------
result: torch.Tensor
Tensor containing the mapping of the edge vector to a d × d matrix, where d denotes the dimension of the internal hidden representation of each node in the graph.
"""
pair_features: torch.Tensor
atom_features: torch.Tensor
atom_to_pair: torch.Tensor
pair_features, atom_features, atom_to_pair = inputs

A: torch.Tensor = torch.add(torch.matmul(pair_features, self.W), self.b)
A = torch.reshape(A, (-1, self.n_hidden, self.n_hidden))
out: torch.Tensor = torch.unsqueeze(atom_features[atom_to_pair[:, 1]],
dim=2)
out_squeeze: torch.Tensor = torch.squeeze(torch.matmul(A, out), dim=2)
ind: torch.Tensor = atom_to_pair[:, 0]

result: torch.Tensor = segment_sum(out_squeeze, ind)

return result
3 changes: 3 additions & 0 deletions docs/source/api_reference/layers.rst
Expand Up @@ -207,6 +207,9 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.DTNNEmbedding
:members:

.. autoclass:: deepchem.models.torch_models.layers.EdgeNetwork
:members:

Grover Layers
^^^^^^^^^^^^^

Expand Down

0 comments on commit 07b2f6c

Please sign in to comment.