Skip to content

Commit

Permalink
Merge pull request #3493 from shreyasvinaya/week6
Browse files Browse the repository at this point in the history
MolGANEncoderLayer Porting
  • Loading branch information
rbharath committed Jul 28, 2023
2 parents e51353b + 3f1f481 commit 0a566d1
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 4 deletions.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion deepchem/models/torch_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from deepchem.models.torch_models.mat import MAT, MATModel
from deepchem.models.torch_models.megnet import MEGNetModel
from deepchem.models.torch_models.normalizing_flows_pytorch import NormalizingFlow
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork, WeaveLayer, WeaveGather, MolGANConvolutionLayer, MolGANAggregationLayer, MolGANMultiConvolutionLayer
from deepchem.models.torch_models.layers import MultilayerPerceptron, CNNModule, CombineMeanStd, WeightedLinearCombo, AtomicConvolution, NeighborList, SetGather, EdgeNetwork, WeaveLayer, WeaveGather, MolGANConvolutionLayer, MolGANAggregationLayer, MolGANMultiConvolutionLayer, MolGANEncoderLayer
from deepchem.models.torch_models.cnn import CNN
from deepchem.models.torch_models.attention import ScaledDotProductAttention, SelfAttention
from deepchem.models.torch_models.grover import GroverModel, GroverPretrain, GroverFinetune
Expand Down
138 changes: 135 additions & 3 deletions deepchem/models/torch_models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3249,7 +3249,7 @@ def __init__(self,
activation=torch.tanh,
dropout_rate: float = 0.0,
name: str = "",
**kwargs):
prev_shape: int = 0):
"""
Initialize the layer
Expand All @@ -3263,6 +3263,8 @@ def __init__(self,
Used by dropout layer
name: string, optional (default="")
Name of the layer
prev_shape: int, optional (default=0)
Shape of the input tensor
"""

super(MolGANAggregationLayer, self).__init__()
Expand All @@ -3271,8 +3273,12 @@ def __init__(self,
self.dropout_rate: float = dropout_rate
self.name: str = name

self.d1 = nn.Linear(self.units, self.units)
self.d2 = nn.Linear(self.units, self.units)
if prev_shape:
self.d1 = nn.Linear(prev_shape, self.units)
self.d2 = nn.Linear(prev_shape, self.units)
else:
self.d1 = nn.Linear(self.units, self.units)
self.d2 = nn.Linear(self.units, self.units)
self.dropout_layer = nn.Dropout(dropout_rate)

def __repr__(self) -> str:
Expand Down Expand Up @@ -3437,6 +3443,132 @@ def forward(self, inputs: List) -> torch.Tensor:
return hidden_tensor


class MolGANEncoderLayer(nn.Module):
    """
    Main learning layer used by MolGAN model.

    MolGAN is a WGAN type model for generation of small molecules.
    The role of this layer is to simplify model construction: it bundles
    a stack of graph convolution layers followed by a graph aggregation
    layer, which could otherwise be built manually by stacking them.

    Example
    -------
    >>> import torch
    >>> import torch.nn as nn
    >>> import torch.nn.functional as F
    >>> vertices = 9
    >>> nodes = 5
    >>> edges = 5
    >>> dropout_rate = 0.0
    >>> adjacency_tensor = torch.randn((1, vertices, vertices, edges))
    >>> node_tensor = torch.randn((1, vertices, nodes))
    >>> graph = MolGANEncoderLayer(units = [(128,64),128], dropout_rate= dropout_rate, edges=edges, nodes=nodes)([adjacency_tensor,node_tensor])
    >>> dense = nn.Linear(128,128)(graph)
    >>> dense = torch.tanh(dense)
    >>> dense = nn.Dropout(dropout_rate)(dense)
    >>> dense = nn.Linear(128,64)(dense)
    >>> dense = torch.tanh(dense)
    >>> dense = nn.Dropout(dropout_rate)(dense)
    >>> output = nn.Linear(64,1)(dense)

    References
    ----------
    .. [1] Nicola De Cao et al. "MolGAN: An implicit generative model
        for small molecular graphs", https://arxiv.org/abs/1805.11973
    """

    def __init__(self,
                 units: List = [(128, 64), 128],
                 activation: Callable = torch.tanh,
                 dropout_rate: float = 0.0,
                 edges: int = 5,
                 nodes: int = 5,
                 name: str = ""):
        """
        Initialize the layer

        Parameters
        ----------
        units: List, optional (default=[(128,64),128])
            Two-element list. The first element holds the dimensions used
            by the consecutive convolution layers (the more values, the
            more convolution layers are invoked); the second element is
            the number of units of the aggregation layer.
        activation: function, optional (default=Tanh)
            activation function used across model, default is Tanh
        dropout_rate: float, optional (default=0.0)
            Used by dropout layer
        edges: int, optional (default=5)
            Controls how many dense layers use for single convolution unit.
            Typically matches number of bond types used in the molecule.
        nodes: int, optional (default=5)
            Number of features in node tensor
        name: string, optional (default="")
            Name of the layer

        Raises
        ------
        ValueError
            If ``units`` does not contain exactly two values.
        """

        super(MolGANEncoderLayer, self).__init__()
        if len(units) != 2:
            raise ValueError("units parameter must contain two values")
        # Keep a private copy so the shared default list is never exposed
        # through the instance; __repr__ reports this value.
        self.units: List = list(units)
        self.graph_convolution_units, self.auxiliary_units = units
        self.activation = activation
        self.dropout_rate = dropout_rate
        self.edges = edges
        self.nodes = nodes
        self.name = name

        self.multi_graph_convolution_layer = MolGANMultiConvolutionLayer(
            units=self.graph_convolution_units,
            nodes=nodes,
            activation=self.activation,
            dropout_rate=self.dropout_rate,
            edges=self.edges)
        # The aggregation layer consumes the convolution output concatenated
        # with the node features, hence the summed input width.
        self.graph_aggregation_layer = MolGANAggregationLayer(
            units=self.auxiliary_units,
            activation=self.activation,
            dropout_rate=self.dropout_rate,
            prev_shape=self.graph_convolution_units[-1] + nodes)

    def __repr__(self) -> str:
        """
        String representation of the layer

        Returns
        -------
        string
            String representation of the layer
        """
        return (f"{self.__class__.__name__}(units={self.units}, "
                f"activation={self.activation}, "
                f"dropout_rate={self.dropout_rate}, edges={self.edges})")

    def forward(self, inputs: List) -> torch.Tensor:
        """
        Invoke this layer

        Parameters
        ----------
        inputs: list
            List of two input matrices, adjacency tensor and node features
            tensors in one-hot encoding format. An optional third element
            is treated as a hidden tensor and concatenated as well.

        Returns
        --------
        encoder tensor: torch.Tensor
            Tensor that been through number of convolutions followed
            by aggregation.
        """

        output = self.multi_graph_convolution_layer(inputs)
        node_tensor = inputs[1]

        if len(inputs) > 2:
            # NOTE(review): the aggregation layer's input width is fixed to
            # graph_convolution_units[-1] + nodes in __init__, so a hidden
            # tensor with non-zero width presumably cannot pass through the
            # aggregation layer as configured -- confirm before relying on
            # this branch.
            hidden_tensor = inputs[2]
            annotations = torch.cat((output, hidden_tensor, node_tensor), -1)
        else:
            annotations = torch.cat((output, node_tensor), -1)

        return self.graph_aggregation_layer(annotations)


class DTNNStep(nn.Module):
"""DTNNStep Layer for DTNN model.
Expand Down
125 changes: 125 additions & 0 deletions deepchem/models/torch_models/tests/test_molgan_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,128 @@ def test_multigraph_convolution_layer_values():
np.load('deepchem/models/tests/assets/molgan_multi_conv_layer_op.npy').
astype(np.float32))
assert torch.allclose(output, output_tensor, atol=1e-04)


@pytest.mark.torch
def test_graph_encoder_layer_shape():
    """
    Check the output shape and the stored hyperparameters of
    MolGANEncoderLayer for a single random graph.
    """
    from deepchem.models.torch_models.layers import MolGANEncoderLayer
    vertices = 9
    nodes = 5
    edges = 5
    first_convolution_unit = 128
    second_convolution_unit = 64
    aggregation_unit = 128
    units = [(first_convolution_unit, second_convolution_unit),
             aggregation_unit]

    # Pass `nodes` explicitly so the layer and the node tensor below are
    # sized from the same variable instead of relying on the default.
    layer = MolGANEncoderLayer(units=units, nodes=nodes, edges=edges)
    adjacency_tensor = torch.randn((1, vertices, vertices, edges))
    node_tensor = torch.randn((1, vertices, nodes))
    output = layer([adjacency_tensor, node_tensor])

    assert output.shape == (1, aggregation_unit)
    assert layer.graph_convolution_units == (first_convolution_unit,
                                             second_convolution_unit)
    assert layer.auxiliary_units == aggregation_unit
    assert layer.activation == torch.tanh
    assert layer.edges == edges
    assert layer.dropout_rate == 0.0


@pytest.mark.torch
def test_graph_encoder_layer_values():
    """
    Test to check the values of the Graph Encoder Layer.

    It first loads the weights of the TF model,
    then it transfers the weights to the torch model:
    1. MultiConvolution Layer
        1.1 First Convolution Layer
        1.2 Rest of the Convolution Layers
    2. Aggregation Layer
    Then it loads the input tensors and checks the output
    against a stored reference output.
    """
    from deepchem.models.torch_models.layers import MolGANEncoderLayer
    nodes = 5
    edges = 5
    first_convolution_unit = 128
    second_convolution_unit = 64
    aggregation_unit = 128
    units = [(first_convolution_unit, second_convolution_unit),
             aggregation_unit]

    torch.manual_seed(21)
    # The asset is a pickled dict mapping TF variable names to arrays,
    # hence allow_pickle=True and .item().
    tf_weights = np.load(
        'deepchem/models/tests/assets/molgan_encoder_layer_weights.npy',
        allow_pickle=True).item()
    torch_model_encoder = MolGANEncoderLayer(units=units,
                                             nodes=nodes,
                                             edges=edges,
                                             name='layer1')

    x = 12  # the starting number for the dense layers in the tf model weights
    with torch.no_grad():
        # Testing MultiConvolution Layer

        # Testing First Convolution Layer
        # dense1 layer - list of dense layers
        for idx, dense in enumerate(
                torch_model_encoder.multi_graph_convolution_layer.
                first_convolution.dense1):
            weight_name = f'layer1///dense_{idx+x}/kernel:0'
            bias_name = f'layer1///dense_{idx+x}/bias:0'

            # Kernels are transposed before assignment; presumably the TF
            # kernel layout is the transpose of the torch weight layout.
            dense.weight.data = torch.from_numpy(
                np.transpose(tf_weights[weight_name]))
            dense.bias.data = torch.from_numpy(tf_weights[bias_name])
        # Step past the last dense1 key so idx + x addresses the key that
        # follows the dense1 entries.
        idx += 1

        # dense2 layer - single dense layer
        torch_model_encoder.multi_graph_convolution_layer.first_convolution.dense2.weight.data = torch.from_numpy(
            np.transpose(tf_weights[f'layer1///dense_{idx+x}/kernel:0']))
        torch_model_encoder.multi_graph_convolution_layer.first_convolution.dense2.bias.data = torch.from_numpy(
            tf_weights[f'layer1///dense_{idx+x}/bias:0'])
        # Advance the key offset past the first convolution's weights.
        # NOTE(review): the offsets used here only line up with the
        # aggregation keys dense_22/dense_23 below if each convolution
        # holds four dense1 entries -- confirm against the layer definition.
        x += 5

        # Testing rest of the Multi convolution layer
        for idx_, layer in enumerate(
                torch_model_encoder.multi_graph_convolution_layer.gcl):
            # dense1 layer - list of dense layers
            for idx, dense in enumerate(layer.dense1):
                weight_name = f'layer1///dense_{idx+x}/kernel:0'
                bias_name = f'layer1///dense_{idx+x}/bias:0'
                dense.weight.data = torch.from_numpy(
                    np.transpose(tf_weights[weight_name]))
                dense.bias.data = torch.from_numpy(tf_weights[bias_name])
            # idx still holds the last dense1 index; bumping x makes
            # idx + x point at the key right after the dense1 entries.
            x += 1

            # dense2 layer - single dense layer
            layer.dense2.weight.data = torch.from_numpy(
                np.transpose(tf_weights[f'layer1///dense_{idx+x}/kernel:0']))
            layer.dense2.bias.data = torch.from_numpy(
                tf_weights[f'layer1///dense_{idx+x}/bias:0'])

        # Testing Aggregation Layer
        # Note the aggregation keys use the 'layer1//' prefix (two slashes),
        # unlike the 'layer1///' prefix of the convolution keys above.
        torch_model_encoder.graph_aggregation_layer.d1.weight.data = torch.from_numpy(
            np.transpose(tf_weights['layer1//dense_22/kernel:0']))
        torch_model_encoder.graph_aggregation_layer.d1.bias.data = torch.from_numpy(
            tf_weights['layer1//dense_22/bias:0'])
        torch_model_encoder.graph_aggregation_layer.d2.weight.data = torch.from_numpy(
            np.transpose(tf_weights['layer1//dense_23/kernel:0']))
        torch_model_encoder.graph_aggregation_layer.d2.bias.data = torch.from_numpy(
            tf_weights['layer1//dense_23/bias:0'])

    # Loading input tensors
    adjacency_tensor = torch.from_numpy(
        np.load('deepchem/models/tests/assets/molgan_adj_tensor.npy').astype(
            np.float32))
    node_tensor = torch.from_numpy(
        np.load('deepchem/models/tests/assets/molgan_nod_tensor.npy').astype(
            np.float32))

    # Testing output
    output = torch_model_encoder([adjacency_tensor, node_tensor])
    output_tensor = torch.from_numpy(
        np.load(
            'deepchem/models/tests/assets/molgan_encoder_layer_op.npy').astype(
                np.float32))
    assert torch.allclose(output, output_tensor, atol=1e-04)
3 changes: 3 additions & 0 deletions docs/source/api_reference/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ Torch Layers
.. autoclass:: deepchem.models.torch_models.layers.MolGANMultiConvolutionLayer
:members:

.. autoclass:: deepchem.models.torch_models.layers.MolGANEncoderLayer
:members:

.. autoclass:: deepchem.models.torch_models.layers.DTNNStep
:members:

Expand Down
1 change: 1 addition & 0 deletions docs/source/api_reference/torch_layers.csv
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ WeaveGather,`ref <https://pubmed.ncbi.nlm.nih.gov/27558503/>`_, WeaveModel
MolGANConvolutionLayer,`ref <https://arxiv.org/abs/1805.11973>`_, MolGan
MolGANAggregationLayer,`ref <https://arxiv.org/abs/1805.11973>`_, MolGan
MolGANMultiConvolutionLayer,`ref <https://arxiv.org/abs/1805.11973>`_, MolGan
MolGANEncoderLayer, `ref <https://arxiv.org/abs/1805.11973>`_, MolGan
DTNNEmbedding, `ref <https://arxiv.org/abs/1609.08259>`_, DTNNModel
DTNNStep, `ref <https://arxiv.org/abs/1609.08259>`_, DTNNModel
DTNNGather, `ref <https://arxiv.org/abs/1609.08259>`_, DTNNModel

0 comments on commit 0a566d1

Please sign in to comment.