diff --git a/deepchem/models/tests/assets/edgenetwork_result.npy b/deepchem/models/tests/assets/edgenetwork_result.npy
new file mode 100644
index 0000000000..e6115f2063
Binary files /dev/null and b/deepchem/models/tests/assets/edgenetwork_result.npy differ
diff --git a/deepchem/models/tests/assets/edgenetwork_weights.npy b/deepchem/models/tests/assets/edgenetwork_weights.npy
new file mode 100644
index 0000000000..e6699c9bba
Binary files /dev/null and b/deepchem/models/tests/assets/edgenetwork_weights.npy differ
diff --git a/deepchem/models/tests/test_layers.py b/deepchem/models/tests/test_layers.py
index 6a477e30a0..c9f016e49a 100644
--- a/deepchem/models/tests/test_layers.py
+++ b/deepchem/models/tests/test_layers.py
@@ -1013,3 +1013,70 @@ def test_dtnn_embedding():
     result_torch = embedding_layer_torch(torch.tensor([3, 2, 4]))
     assert torch.allclose(torch.tensor(results_tf), result_torch)
     assert result_torch.shape == (3, 5)
+
+
+@pytest.mark.torch
+def test_edge_network():
+    """Test invoking the Torch equivalent of EdgeNetwork."""
+    # init parameters
+    n_pair_features = 14
+    n_hidden = 75  # based on weave featurizer
+    torch_init = 'xavier_uniform_'
+
+    # generate features for testing
+    mols = ["CCC"]
+    featurizer = dc.feat.WeaveFeaturizer()
+    features = featurizer.featurize(mols)
+    X_b = np.asarray([features[0]])
+    X_b = dc.data.pad_features(1, X_b)
+
+    atom_feat = []
+    pair_feat = []
+    atom_to_pair = []
+    start = 0
+    for mol in X_b:
+        n_atoms = mol.get_num_atoms()
+
+        # index of pair features
+        C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
+        atom_to_pair.append(
+            np.transpose(np.array([C1.flatten() + start,
+                                   C0.flatten() + start])))
+        start = start + n_atoms
+
+        # atom features
+        atom_feat.append(mol.get_atom_features())
+
+        # pair features
+        pair_feat.append(
+            np.reshape(mol.get_pair_features(),
+                       (n_atoms * n_atoms, n_pair_features)))
+
+    atom_features = np.concatenate(atom_feat, axis=0)
+    pair_features = np.concatenate(pair_feat, axis=0)
+    atom_to_pair_array = np.concatenate(atom_to_pair, axis=0)
+
+    # tensors for torch layer
+    torch_pair_features = torch.Tensor(pair_features)
+    torch_atom_features = torch.Tensor(atom_features)
+    torch_atom_to_pair = torch.Tensor(atom_to_pair_array)
+    torch_atom_to_pair = torch.squeeze(torch_atom_to_pair.to(torch.int64),
+                                       dim=0)
+
+    torch_inputs = [
+        torch_pair_features, torch_atom_features, torch_atom_to_pair
+    ]
+
+    torch_layer = dc.models.torch_models.layers.EdgeNetwork(
+        n_pair_features, n_hidden, torch_init)
+
+    # assigning tensorflow layer weights to torch layer
+    torch_layer.W = torch.from_numpy(
+        np.load("deepchem/models/tests/assets/edgenetwork_weights.npy"))
+
+    torch_result = torch_layer(torch_inputs)
+
+    assert np.allclose(
+        np.array(torch_result),
+        np.load("deepchem/models/tests/assets/edgenetwork_result.npy"),
+        atol=1e-04)
diff --git a/deepchem/models/torch_models/__init__.py b/deepchem/models/torch_models/__init__.py
index 0a082e433a..36a8665d0a 100644
--- a/deepchem/models/torch_models/__init__.py
+++ b/deepchem/models/torch_models/__init__.py
@@ -16,7 +16,7 @@ from deepchem.models.torch_models.mat import MAT, MATModel
 from deepchem.models.torch_models.megnet import MEGNetModel
 from deepchem.models.torch_models.normalizing_flows_pytorch import NormalizingFlow
-from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather
+from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork
 from deepchem.models.torch_models.cnn import CNN
 from deepchem.models.torch_models.attention import ScaledDotProductAttention, SelfAttention
 from deepchem.models.torch_models.grover import GroverModel, GroverPretrain, GroverFinetune
diff --git a/deepchem/models/torch_models/layers.py b/deepchem/models/torch_models/layers.py
index bfa77fa598..1c56fb97e0 100644
--- a/deepchem/models/torch_models/layers.py
+++ b/deepchem/models/torch_models/layers.py
@@ -16,7 +16,7 @@ pass
 from deepchem.utils.typing import OneOrMany, ActivationFn, ArrayLike
-from deepchem.utils.pytorch_utils import get_activation
+from deepchem.utils.pytorch_utils import get_activation, segment_sum
 from torch.nn import init as initializers
@@ -3056,3 +3056,88 @@ def forward(self, inputs: torch.Tensor):
         atom_enbeddings = torch.nn.functional.embedding(atom_number,
                                                          self.embedding_list)
         return atom_enbeddings
+
+
+class EdgeNetwork(nn.Module):
+    """The EdgeNetwork module is a PyTorch submodule designed for message passing in graph neural networks.
+
+    Examples
+    --------
+    >>> pair_features = torch.rand((4, 2), dtype=torch.float32)
+    >>> atom_features = torch.rand((5, 2), dtype=torch.float32)
+    >>> atom_to_pair = []
+    >>> n_atoms = 2
+    >>> start = 0
+    >>> C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
+    >>> atom_to_pair.append(np.transpose(np.array([C1.flatten() + start, C0.flatten() + start])))
+    >>> atom_to_pair = torch.Tensor(atom_to_pair)
+    >>> atom_to_pair = torch.squeeze(atom_to_pair.to(torch.int64), dim=0)
+    >>> inputs = [pair_features, atom_features, atom_to_pair]
+    >>> n_pair_features = 2
+    >>> n_hidden = 2
+    >>> init = 'xavier_uniform_'
+    >>> layer = EdgeNetwork(n_pair_features, n_hidden, init)
+    >>> result = layer(inputs)
+    >>> result.shape[1]
+    2
+    """
+
+    def __init__(self,
+                 n_pair_features: int = 8,
+                 n_hidden: int = 100,
+                 init: str = 'xavier_uniform_',
+                 **kwargs):
+        """Initializes an EdgeNetwork layer.
+
+        Parameters
+        ----------
+        n_pair_features: int, optional
+            The length of the pair features vector.
+        n_hidden: int, optional
+            Number of hidden units in the message passing phase.
+        init: str, optional
+            Initialization function to be used in the message passing layer.
+        """
+
+        super(EdgeNetwork, self).__init__(**kwargs)
+        self.n_pair_features: int = n_pair_features
+        self.n_hidden: int = n_hidden
+        self.init: str = init
+
+        init_func: Callable = getattr(initializers, self.init)
+        self.W: torch.Tensor = init_func(
+            torch.empty([self.n_pair_features, self.n_hidden * self.n_hidden]))
+        self.b: torch.Tensor = torch.zeros((self.n_hidden * self.n_hidden,))
+        self.built: bool = True
+
+    def __repr__(self) -> str:
+        return (
+            f'{self.__class__.__name__}(n_pair_features:{self.n_pair_features},n_hidden:{self.n_hidden},init:{self.init})'
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """
+        Parameters
+        ----------
+        inputs: List[torch.Tensor]
+            A list of three tensors: pair_features, atom_features and atom_to_pair.
+
+        Returns
+        -------
+        result: torch.Tensor
+            Aggregated messages of shape (n_atoms, n_hidden). Each pair (edge) feature vector is mapped to a d x d matrix, where d denotes the dimension of the internal hidden representation of each node in the graph; the matrix is applied to the neighbouring atom's features and the resulting messages are summed per atom.
+        """
+        pair_features: torch.Tensor
+        atom_features: torch.Tensor
+        atom_to_pair: torch.Tensor
+        pair_features, atom_features, atom_to_pair = inputs
+
+        A: torch.Tensor = torch.add(torch.matmul(pair_features, self.W), self.b)
+        A = torch.reshape(A, (-1, self.n_hidden, self.n_hidden))
+        out: torch.Tensor = torch.unsqueeze(atom_features[atom_to_pair[:, 1]],
+                                            dim=2)
+        out_squeeze: torch.Tensor = torch.squeeze(torch.matmul(A, out), dim=2)
+        ind: torch.Tensor = atom_to_pair[:, 0]
+
+        result: torch.Tensor = segment_sum(out_squeeze, ind)
+
+        return result
diff --git a/docs/source/api_reference/layers.rst b/docs/source/api_reference/layers.rst
index 7ca27bda6b..d739708377 100644
--- a/docs/source/api_reference/layers.rst
+++ b/docs/source/api_reference/layers.rst
@@ -207,6 +207,9 @@ Torch Layers
 .. autoclass:: deepchem.models.torch_models.layers.DTNNEmbedding
   :members:
 
+.. autoclass:: deepchem.models.torch_models.layers.EdgeNetwork
+  :members:
+
 Grover Layers
 ^^^^^^^^^^^^^
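Usage note: a minimal sketch of driving the new EdgeNetwork layer end to end, mirroring test_edge_network above. The molecule "CCC", the WeaveFeaturizer sizes (14 pair features, 75 atom features) and the printed shape are illustrative assumptions rather than additional test assets.

# Minimal sketch: featurize one molecule with WeaveFeaturizer and pass the
# pair/atom features through EdgeNetwork (assumed sizes: 14 pair features,
# 75 atom features, so n_hidden must be 75 to match the atom feature length).
import numpy as np
import torch
import deepchem as dc
from deepchem.models.torch_models.layers import EdgeNetwork

mol = dc.feat.WeaveFeaturizer().featurize(["CCC"])[0]
n_atoms = mol.get_num_atoms()
n_pair_features = 14
n_hidden = 75

# Every ordered (i, j) atom index pair, matching the fully connected pair matrix.
C0, C1 = np.meshgrid(np.arange(n_atoms), np.arange(n_atoms))
atom_to_pair = torch.tensor(np.stack([C1.flatten(), C0.flatten()], axis=1),
                            dtype=torch.int64)

atom_features = torch.Tensor(mol.get_atom_features())
pair_features = torch.Tensor(
    mol.get_pair_features().reshape(n_atoms * n_atoms, n_pair_features))

layer = EdgeNetwork(n_pair_features=n_pair_features, n_hidden=n_hidden)
result = layer([pair_features, atom_features, atom_to_pair])
print(result.shape)  # one n_hidden-dimensional message per atom: (3, 75)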