Add Atomic Convolution Model in Pytorch #3698

Merged 2 commits on Dec 8, 2023
1 change: 1 addition & 0 deletions deepchem/models/torch_models/__init__.py
@@ -25,6 +25,7 @@
 from deepchem.models.torch_models.readout import GroverReadout
 from deepchem.models.torch_models.dtnn import DTNN, DTNNModel
 from deepchem.models.torch_models.seqtoseq import SeqToSeq, SeqToSeqModel
+from deepchem.models.torch_models.acnn import AtomConvModel
 try:
     from deepchem.models.torch_models.dmpnn import DMPNN, DMPNNModel
     from deepchem.models.torch_models.gnn import GNN, GNNHead, GNNModular
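With the one-line export above, the new model becomes importable directly from the torch_models namespace. A minimal sanity-check sketch, assuming a DeepChem build that includes this change (the sizes below are arbitrary placeholders, not recommended values):

from deepchem.models.torch_models import AtomConvModel

# Construct a tiny model just to confirm the class resolves and wires up.
model = AtomConvModel(n_tasks=1,
                      batch_size=1,
                      layer_sizes=[10],
                      frag1_num_atoms=10,
                      frag2_num_atoms=20,
                      complex_num_atoms=30)
print(type(model).__name__)  # -> AtomConvModel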
299 changes: 299 additions & 0 deletions deepchem/models/torch_models/acnn.py
@@ -0,0 +1,299 @@
import torch
import numpy as np
from deepchem.models.torch_models.torch_model import TorchModel
from deepchem.models.torch_models.layers import AtomicConv
from deepchem.data import Dataset
from deepchem.models.losses import L2Loss

from typing import List, Callable, Optional, Sequence, Tuple, Iterable
from deepchem.utils.typing import OneOrMany, ActivationFn


class AtomConvModel(TorchModel):
"""An Atomic Convolutional Neural Network (ACNN) for energy score prediction.

The network follows the design of a graph convolutional network but in this case the graph is represented
as a 3D structure of the molecule. The objective of this model is to train models and predict energetic
state starting from the spatial geometry of the model [1].

References
----------
.. [1] Gomes, Joseph, et al. "Atomic convolutional networks for predicting protein-ligand binding affinity." arXiv preprint arXiv:1703.10603 (2017).

Examples
--------
>>> from deepchem.models.torch_models import AtomConvModel
>>> frag1_num_atoms = 100 # atoms for ligand
>>> frag2_num_atoms = 1200 # atoms for protein
>>> complex_num_atoms = frag1_num_atoms + frag2_num_atoms
>>> batch_size = 1
>>> # Initialize the model
>>> atomic_convnet = AtomConvModel(n_tasks=1,
batch_size=batch_size,
layer_sizes=[
10,
],
frag1_num_atoms=frag1_num_atoms,
frag2_num_atoms=frag2_num_atoms,
complex_num_atoms=complex_num_atoms)

# Creates a set of dummy features that contain the coordinate and
# neighbor-list features required by the AtomicConvModel.
>>> # Preparing the dataset
>>> features = []
>>> frag1_coords = np.random.rand(frag1_num_atoms, 3)
>>> frag1_nbr_list = {i: [] for i in range(frag1_num_atoms)}
>>> frag1_z = np.random.randint(10, size=(frag1_num_atoms))
>>> frag2_coords = np.random.rand(frag2_num_atoms, 3)
>>> frag2_nbr_list = {i: [] for i in range(frag2_num_atoms)}
>>> frag2_z = np.random.randint(10, size=(frag2_num_atoms))
>>> system_coords = np.random.rand(complex_num_atoms, 3)
>>> system_nbr_list = {i: [] for i in range(complex_num_atoms)}
>>> system_z = np.random.randint(10, size=(complex_num_atoms))
>>> features.append((frag1_coords, frag1_nbr_list, frag1_z, frag2_coords, frag2_nbr_list, frag2_z, system_coords, system_nbr_list, system_z))
>>> features = np.asarray(features)
>>> labels = np.zeros(batch_size)
>>> train = NumpyDataset(features, labels)
>>> atomic_convnet.fit(train, nb_epoch=1)
>>> preds = atomic_convnet.predict(train)
"""

    def __init__(self,
                 n_tasks: int,
                 frag1_num_atoms: int = 70,
                 frag2_num_atoms: int = 634,
                 complex_num_atoms: int = 701,
                 max_num_neighbors: int = 12,
                 batch_size: int = 24,
                 atom_types: Sequence[float] = [
                     6., 7., 8., 9., 11., 12., 15., 16., 17., 20., 25., 30.,
                     35., 53., -1.
                 ],
                 radial: Sequence[Sequence[float]] = [[
                     1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5,
                     7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 10.5, 11.0, 11.5, 12.0
                 ], [0.0, 4.0, 8.0], [0.4]],
                 layer_sizes: List[int] = [100],
                 weight_init_stddevs: OneOrMany[float] = 0.02,
                 bias_init_consts: OneOrMany[float] = 1.0,
                 weight_decay_penalty: float = 0.0,
                 weight_decay_penalty_type: str = "l2",
                 dropouts: OneOrMany[float] = 0.5,
                 activation_fns: OneOrMany[ActivationFn] = ['relu'],
                 residual: bool = False,
                 learning_rate: float = 0.001,
                 **kwargs) -> None:
"""TorchModel wrapper for ACNN

Parameters
----------
n_tasks: int
number of tasks
frag1_num_atoms: int
Number of atoms in first fragment.
frag2_num_atoms: int
Number of atoms in second fragment.
complex_num_atoms: int
Number of atoms in complex.
max_num_neighbors: int
Maximum number of neighbors possible for an atom. Recall neighbors
are spatial neighbors.
batch_size: int
Size of the batch.
atom_types: list
List of atoms recognized by model. Atoms are indicated by their
nuclear numbers.
radial: list
Radial parameters used in the atomic convolution transformation.
layer_sizes: list
the size of each dense layer in the network. The length of
this list determines the number of layers.
weight_init_stddevs: list or float
the standard deviation of the distribution to use for weight
initialization of each layer. The length of this list should
equal len(layer_sizes). Alternatively, this may be a single
value instead of a list, where the same value is used
for every layer.
bias_init_consts: list or float
the value to initialize the biases in each layer. The
length of this list should equal len(layer_sizes).
Alternatively, this may be a single value instead of a list, where the same value is used for every layer.
dropouts: list or float
the dropout probability to use for each layer. The length of this list should equal len(layer_sizes).
Alternatively, this may be a single value instead of a list, where the same value is used for every layer.
activation_fns: list or object
the Tensorflow activation function to apply to each layer. The length of this list should equal
len(layer_sizes). Alternatively, this may be a single value instead of a list, where the
same value is used for every layer.
residual: bool
Whether to use residual connections.
learning_rate: float
the learning rate to use for fitting.
"""

        self.n_tasks = n_tasks
        self.complex_num_atoms = complex_num_atoms
        self.frag1_num_atoms = frag1_num_atoms
        self.frag2_num_atoms = frag2_num_atoms
        self.max_num_neighbors = max_num_neighbors
        self.batch_size = batch_size
        self.atom_types = atom_types

        self.model = AtomicConv(n_tasks=n_tasks,
                                frag1_num_atoms=frag1_num_atoms,
                                frag2_num_atoms=frag2_num_atoms,
                                complex_num_atoms=complex_num_atoms,
                                max_num_neighbors=max_num_neighbors,
                                batch_size=batch_size,
                                atom_types=atom_types,
                                radial=radial,
                                layer_sizes=layer_sizes,
                                weight_init_stddevs=weight_init_stddevs,
                                bias_init_consts=bias_init_consts,
                                dropouts=dropouts,
                                activation_fns=activation_fns,
                                residual=residual,
                                learning_rate=learning_rate)

        regularization_loss: Optional[Callable]

        if weight_decay_penalty != 0:
            weights = [layer.weight for layer in self.model.layers]
            if weight_decay_penalty_type == 'l1':
                regularization_loss = lambda: weight_decay_penalty * torch.sum(  # noqa: E731
                    torch.stack([torch.abs(w).sum() for w in weights]))
            else:
                regularization_loss = lambda: weight_decay_penalty * torch.sum(  # noqa: E731
                    torch.stack([torch.square(w).sum() for w in weights]))
        else:
            regularization_loss = None

        loss = L2Loss()

        super(AtomConvModel,
              self).__init__(self.model,
                             loss=loss,
                             batch_size=batch_size,
                             regularization_loss=regularization_loss,
                             **kwargs)

    def default_generator(
            self,
            dataset: Dataset,
            epochs: int = 1,
            mode: str = 'fit',
            deterministic: bool = True,
            pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
        """Convert a dataset into the tensors needed for learning.

        Parameters
        ----------
        dataset: `dc.data.Dataset`
            Dataset to convert.
        epochs: int, optional (default 1)
            Number of times to walk over `dataset`.
        mode: str, optional (default 'fit')
            Ignored in this implementation.
        deterministic: bool, optional (default True)
            Whether the dataset should be walked in a deterministic fashion.
        pad_batches: bool, optional (default True)
            If True, each returned batch will have size `self.batch_size`.

        Returns
        -------
        Iterator which walks over the batches.
        """

        batch_size = self.batch_size

        def replace_atom_types(z):
            """Replace any atom type not recognized by the model with -1.

            Parameters
            ----------
            z: np.ndarray
                Nuclear numbers for the atoms in a fragment or complex.

            Returns
            -------
            The array of nuclear numbers with every entry that is not in
            `self.atom_types` replaced by -1.
            """
            np.putmask(z, np.isin(z, list(self.atom_types), invert=True), -1)
            return z

        for epoch in range(epochs):
            for ind, (F_b, y_b, w_b, ids_b) in enumerate(
                    dataset.iterbatches(batch_size,
                                        deterministic=True,
                                        pad_batches=pad_batches)):

                N = self.complex_num_atoms
                N_1 = self.frag1_num_atoms
                N_2 = self.frag2_num_atoms
                M = self.max_num_neighbors

                batch_size = F_b.shape[0]
                num_features = F_b[0][0].shape[1]
                frag1_X_b = np.zeros((batch_size, N_1, num_features))
                for i in range(batch_size):
                    frag1_X_b[i] = F_b[i][0]

                frag2_X_b = np.zeros((batch_size, N_2, num_features))
                for i in range(batch_size):
                    frag2_X_b[i] = F_b[i][3]

                complex_X_b = np.zeros((batch_size, N, num_features))
                for i in range(batch_size):
                    complex_X_b[i] = F_b[i][6]

                frag1_Nbrs = np.zeros((batch_size, N_1, M))
                frag1_Z_b = np.zeros((batch_size, N_1))
                for i in range(batch_size):
                    z = replace_atom_types(F_b[i][2])
                    frag1_Z_b[i] = z
                frag1_Nbrs_Z = np.zeros((batch_size, N_1, M))
                for atom in range(N_1):
                    for i in range(batch_size):
                        atom_nbrs = F_b[i][1].get(atom, [])
                        frag1_Nbrs[i, atom, :len(atom_nbrs)] = np.array(atom_nbrs)
                        for j, atom_j in enumerate(atom_nbrs):
                            frag1_Nbrs_Z[i, atom, j] = frag1_Z_b[i, atom_j]

                frag2_Nbrs = np.zeros((batch_size, N_2, M))
                frag2_Z_b = np.zeros((batch_size, N_2))
                for i in range(batch_size):
                    z = replace_atom_types(F_b[i][5])
                    frag2_Z_b[i] = z
                frag2_Nbrs_Z = np.zeros((batch_size, N_2, M))
                for atom in range(N_2):
                    for i in range(batch_size):
                        atom_nbrs = F_b[i][4].get(atom, [])
                        frag2_Nbrs[i, atom, :len(atom_nbrs)] = np.array(atom_nbrs)
                        for j, atom_j in enumerate(atom_nbrs):
                            frag2_Nbrs_Z[i, atom, j] = frag2_Z_b[i, atom_j]

                complex_Nbrs = np.zeros((batch_size, N, M))
                complex_Z_b = np.zeros((batch_size, N))
                for i in range(batch_size):
                    z = replace_atom_types(F_b[i][8])
                    complex_Z_b[i] = z
                complex_Nbrs_Z = np.zeros((batch_size, N, M))
                for atom in range(N):
                    for i in range(batch_size):
                        atom_nbrs = F_b[i][7].get(atom, [])
                        complex_Nbrs[i, atom, :len(atom_nbrs)] = np.array(atom_nbrs)
                        for j, atom_j in enumerate(atom_nbrs):
                            complex_Nbrs_Z[i, atom, j] = complex_Z_b[i, atom_j]

                inputs = [
                    frag1_X_b, frag1_Nbrs, frag1_Nbrs_Z, frag1_Z_b, frag2_X_b,
                    frag2_Nbrs, frag2_Nbrs_Z, frag2_Z_b, complex_X_b,
                    complex_Nbrs, complex_Nbrs_Z, complex_Z_b
                ]

                y_b = np.reshape(y_b, newshape=(batch_size, 1))

                yield (inputs, [y_b], [w_b])
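
Note: default_generator unpacks each featurized sample (a 9-tuple of coordinates, neighbor list, and atomic numbers for the ligand, the protein, and the combined complex) into the twelve arrays the underlying AtomicConv layer consumes. A minimal inspection sketch, assuming the dummy `train` dataset and `atomic_convnet` model from the class docstring example above:

# Pull one batch from the generator and print the shape of each input array.
gen = atomic_convnet.default_generator(train, epochs=1)
inputs, labels, weights = next(iter(gen))
names = [
    "frag1_X", "frag1_Nbrs", "frag1_Nbrs_Z", "frag1_Z", "frag2_X",
    "frag2_Nbrs", "frag2_Nbrs_Z", "frag2_Z", "complex_X", "complex_Nbrs",
    "complex_Nbrs_Z", "complex_Z"
]
for name, arr in zip(names, inputs):
    print(name, arr.shape)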
34 changes: 26 additions & 8 deletions deepchem/models/torch_models/layers.py
@@ -2779,7 +2779,7 @@ def __init__(self,
         self.init = init
         self.n_tasks = n_tasks
 
-        rp = [x for x in itertools.product(*radial)]
+        self.rp = [x for x in itertools.product(*radial)]
 
         frag1_X = np.random.rand(self.batch_size, self.frag1_num_atoms,
                                  3).astype(np.float32)
@@ -2816,17 +2816,17 @@

         flattener = nn.Flatten()
         self._frag1_conv = AtomicConvolution(
-            atom_types=self.atom_types, radial_params=rp,
+            atom_types=self.atom_types, radial_params=self.rp,
             box_size=None)([frag1_X, frag1_nbrs, frag1_nbrs_z])
         flattened1 = nn.Flatten()(self._frag1_conv)
 
         self._frag2_conv = AtomicConvolution(
-            atom_types=self.atom_types, radial_params=rp,
+            atom_types=self.atom_types, radial_params=self.rp,
             box_size=None)([frag2_X, frag2_nbrs, frag2_nbrs_z])
         flattened2 = flattener(self._frag2_conv)
 
         self._complex_conv = AtomicConvolution(
-            atom_types=self.atom_types, radial_params=rp,
+            atom_types=self.atom_types, radial_params=self.rp,
             box_size=None)([complex_X, complex_nbrs, complex_nbrs_z])
         flattened3 = flattener(self._complex_conv)

@@ -2874,21 +2874,39 @@ def forward(self, inputs: OneOrMany[torch.Tensor]):
"""
Parameters
----------
x: torch.Tensor
inputs: torch.Tensor
Input Tensor
Returns
-------
torch.Tensor
Output for each label.
"""

x = self.prev_layer[0]
x = torch.reshape(x, (-1,))
flattener = nn.Flatten()
frag1_conv = AtomicConvolution(atom_types=self.atom_types,
radial_params=self.rp,
box_size=None)(
[inputs[0], inputs[1], inputs[2]])
flattened1 = nn.Flatten()(frag1_conv)

frag2_conv = AtomicConvolution(atom_types=self.atom_types,
radial_params=self.rp,
box_size=None)(
[inputs[4], inputs[5], inputs[6]])
flattened2 = flattener(frag2_conv)

complex_conv = AtomicConvolution(atom_types=self.atom_types,
radial_params=self.rp,
box_size=None)(
[inputs[8], inputs[9], inputs[10]])
flattened3 = flattener(complex_conv)

inputs_x = torch.cat((flattened1, flattened2, flattened3), dim=1)

for layer, activation_fn, dropout in zip(self.layers,
self.activation_fns,
self.dropouts):
x = layer(x)
x = layer(inputs_x)

if dropout > 0:
x = F.dropout(x, dropout)
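
Note: the net effect of the rewritten forward() is that the three AtomicConvolution outputs are flattened and concatenated along the feature axis before entering the dense layers. A toy illustration of that concatenation step with stand-in tensors (the shapes are invented for the sketch; real widths depend on atom counts, atom types, and radial parameters):

import torch
import torch.nn as nn

batch_size = 1
flattener = nn.Flatten()
# Stand-ins for the three AtomicConvolution outputs:
# (batch, num_atoms, feature_width) for ligand, protein, and complex.
frag1_conv = torch.rand(batch_size, 100, 66)
frag2_conv = torch.rand(batch_size, 1200, 66)
complex_conv = torch.rand(batch_size, 1300, 66)

flattened1 = flattener(frag1_conv)    # shape (1, 6600)
flattened2 = flattener(frag2_conv)    # shape (1, 79200)
flattened3 = flattener(complex_conv)  # shape (1, 85800)

# Concatenate along the feature dimension, as in the new forward():
inputs_x = torch.cat((flattened1, flattened2, flattened3), dim=1)
print(inputs_x.shape)  # torch.Size([1, 171600])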