Supervised GNNModular tasks #3343

Merged (19 commits) on Apr 12, 2023
63 changes: 63 additions & 0 deletions deepchem/models/losses.py
@@ -909,6 +909,69 @@ def loss(atom_vocab_task_atom_pred: torch.Tensor,
return loss


class EdgePredictionLoss(Loss):
    """
    Unsupervised graph edge prediction loss.

    The input to this loss must be a BatchGraphData object transformed by
    the negative_edge_sampler function.

    Examples
    --------
    >>> from deepchem.models.losses import EdgePredictionLoss
    >>> from deepchem.feat.graph_data import BatchGraphData, GraphData
    >>> from deepchem.models.torch_models.gnn import negative_edge_sampler
    >>> import torch
    >>> import numpy as np
    >>> emb_dim = 8
    >>> num_nodes_list, num_edge_list = [3, 4, 5], [2, 4, 5]
    >>> num_node_features, num_edge_features = 32, 32
    >>> edge_index_list = [
    ...     np.array([[0, 1], [1, 2]]),
    ...     np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
    ...     np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
    ... ]
    >>> graph_list = [
    ...     GraphData(node_features=np.random.random_sample(
    ...         (num_nodes_list[i], num_node_features)),
    ...               edge_index=edge_index_list[i],
    ...               edge_features=np.random.random_sample(
    ...                   (num_edge_list[i], num_edge_features)),
    ...               node_pos_features=None) for i in range(len(num_edge_list))
    ... ]
    >>> batched_graph = BatchGraphData(graph_list)
    >>> batched_graph = batched_graph.numpy_to_torch()
    >>> neg_sampled = negative_edge_sampler(batched_graph)
    >>> embedding = np.random.random((sum(num_nodes_list), emb_dim))
    >>> embedding = torch.from_numpy(embedding)
    >>> loss_func = EdgePredictionLoss()._create_pytorch_loss()
    >>> loss = loss_func(embedding, neg_sampled)

    References
    ----------
    .. [1] Hu, W. et al. Strategies for Pre-training Graph Neural Networks.
       Preprint at https://doi.org/10.48550/arXiv.1905.12265 (2020).
    """

    def _create_pytorch_loss(self):
        import torch
        self.criterion = torch.nn.BCEWithLogitsLoss()

        def loss(node_emb, inputs):
            # Edges are stored in both directions, so a stride of 2 scores
            # each undirected edge once; the dot product of the endpoint
            # embeddings is the logit that the edge exists.
            positive_score = torch.sum(node_emb[inputs.edge_index[0, ::2]] *
                                       node_emb[inputs.edge_index[1, ::2]],
                                       dim=1)
            negative_score = torch.sum(node_emb[inputs.negative_edge_index[0]] *
                                       node_emb[inputs.negative_edge_index[1]],
                                       dim=1)

            # Positive edges are pushed toward label 1, sampled negative
            # edges toward label 0.
            edge_pred_loss = self.criterion(
                positive_score,
                torch.ones_like(positive_score)) + self.criterion(
                    negative_score, torch.zeros_like(negative_score))
            return edge_pred_loss

        return loss
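For intuition, the scoring rule above can be reproduced with plain torch on toy tensors. This is a minimal sketch with made-up embeddings and edge indices (not DeepChem API); the `::2` stride mirrors how each undirected edge is presumably stored twice, once per direction:

```python
import torch

# Toy setup: 4 nodes with 8-dimensional embeddings.
node_emb = torch.randn(4, 8)
# One positive edge stored in both directions (0->1 and 1->0),
# plus one sampled negative edge (0, 2).
edge_index = torch.tensor([[0, 1], [1, 0]])
negative_edge_index = torch.tensor([[0], [2]])

criterion = torch.nn.BCEWithLogitsLoss()
# Dot product of endpoint embeddings is the logit for "this edge exists";
# the ::2 stride scores each undirected edge once.
positive_score = torch.sum(node_emb[edge_index[0, ::2]] *
                           node_emb[edge_index[1, ::2]], dim=1)
negative_score = torch.sum(node_emb[negative_edge_index[0]] *
                           node_emb[negative_edge_index[1]], dim=1)
loss = criterion(positive_score, torch.ones_like(positive_score)) + \
       criterion(negative_score, torch.zeros_like(negative_score))
```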


def _make_tf_shapes_consistent(output, labels):
"""Try to make inputs have the same shape by adding dimensions of size 1."""
import tensorflow as tf
66 changes: 66 additions & 0 deletions deepchem/models/tests/test_gnn.py
@@ -21,6 +21,36 @@ def get_regression_dataset():
return dataset, metric


def get_multitask_regression_dataset():
    featurizer = SNAPFeaturizer()
    dir = os.path.dirname(os.path.abspath(__file__))

    input_file = os.path.join(dir, 'assets/multitask_regression.csv')
    loader = dc.data.CSVLoader(tasks=['task0', 'task1', 'task2'],
                               feature_field="smiles",
                               featurizer=featurizer)
    dataset = loader.create_dataset(input_file)
    metric = dc.metrics.Metric(dc.metrics.mean_absolute_error,
                               mode="regression")
    return dataset, metric


def get_multitask_classification_dataset():
    featurizer = SNAPFeaturizer()
    dir = os.path.dirname(os.path.abspath(__file__))

    input_file = os.path.join(dir, 'assets/multitask_example.csv')
    loader = dc.data.CSVLoader(tasks=['task0', 'task1', 'task2'],
                               feature_field="smiles",
                               featurizer=featurizer)
    dataset = loader.create_dataset(input_file)
    metric = dc.metrics.Metric(dc.metrics.roc_auc_score,
                               np.mean,
                               mode="classification")
    return dataset, metric


@pytest.mark.torch
def test_GNN_edge_pred():
"""Tests the unsupervised edge prediction task"""
Expand All @@ -31,3 +61,39 @@ def test_GNN_edge_pred():
    loss1 = model.fit(dataset, nb_epoch=5)
    loss2 = model.fit(dataset, nb_epoch=5)
    assert loss2 < loss1


@pytest.mark.torch
def test_GNN_regression():
    from deepchem.models.torch_models.gnn import GNNModular

    dataset, metric = get_regression_dataset()
    model = GNNModular(task="regression")
    model.fit(dataset, nb_epoch=100)
    scores = model.evaluate(dataset, [metric])
    assert scores['mean_absolute_error'] < 0.1


@pytest.mark.torch
def test_GNN_multitask_regression():
    from deepchem.models.torch_models.gnn import GNNModular

    dataset, metric = get_multitask_regression_dataset()
    model = GNNModular(task="regression", num_tasks=3)
    model.fit(dataset, nb_epoch=100)
    scores = model.evaluate(dataset, [metric])
    assert scores['mean_absolute_error'] < 0.1


@pytest.mark.torch
def test_GNN_multitask_classification():
    from deepchem.models.torch_models.gnn import GNNModular

    dataset, metric = get_multitask_classification_dataset()
    model = GNNModular(task="classification", num_tasks=3)
    model.fit(dataset, nb_epoch=100)
    scores = model.evaluate(dataset, [metric])
    assert scores['mean-roc_auc_score'] >= 0.8

133 changes: 86 additions & 47 deletions deepchem/models/torch_models/gnn.py
@@ -2,8 +2,12 @@
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.nn.aggr import AttentionalAggregation, Set2Set
from torch.functional import F
from deepchem.data import Dataset
from deepchem.models.losses import SoftmaxCrossEntropy, EdgePredictionLoss
from deepchem.models.torch_models import ModularTorchModel
from deepchem.feat.graph_data import BatchGraphData
from typing import Iterable, List, Tuple
from deepchem.metrics import to_one_hot

num_atom_type = 120
num_chirality_tag = 3
@@ -52,10 +56,10 @@ class GNN(torch.nn.Module):
>>> smiles = ["C1=CC=CC=C1", "C1=CC=CC=C1C=O", "C1=CC=CC=C1C(=O)O"]
>>> features = featurizer.featurize(smiles)
>>> batched_graph = BatchGraphData(features).numpy_to_torch(device="cuda")
>>> modular = model = GNNModular("gin", 3, 64, 1, "attention", 0, "last", "edge_pred")
>>> modular = GNNModular(emb_dim = 8, task = "edge_pred")
>>> gnnmodel = modular.gnn
>>> print(gnnmodel(batched_graph)[0].shape)
torch.Size([23, 32])

"""

@@ -120,45 +124,52 @@ def forward(self, data: BatchGraphData):
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]

return (node_representation, data)


class GNNHead(torch.nn.Module):
    """
    Prediction head module for the GNNModular model.

    Parameters
    ----------
    pool: Union[function, torch.nn.Module]
        Pooling function or nn.Module to use.
    head: torch.nn.Module
        Prediction head to use.
    task: str
        The type of task. Must be one of "regression", "classification".
    num_tasks: int
        Number of tasks.
    num_classes: int
        Number of classes for classification.
    """

    def __init__(self, pool, head, task, num_tasks, num_classes):
        super().__init__()
        self.pool = pool
        self.head = head
        self.task = task
        self.num_tasks = num_tasks
        self.num_classes = num_classes

    def forward(self, data):
        """
        Forward pass for the GNN head module.

        Parameters
        ----------
        data: tuple
            A tuple containing the node representations and the input graph data.
            node_representation is a torch.Tensor created by the GNN layers;
            input_batch is the original input BatchGraphData.
        """
        node_representation, input_batch = data

        pooled = self.pool(node_representation, input_batch.graph_index)
        out = self.head(pooled)
        if self.task == "classification":
            # Unflatten the logits to (batch, num_tasks, num_classes) so each
            # task is scored separately.
            out = torch.reshape(out, (-1, self.num_tasks, self.num_classes))
        return out
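To make the shape contract of that reshape concrete, here is a small sketch with made-up sizes (not part of the diff):

```python
import torch

batch_size, num_tasks, num_classes = 16, 3, 2
# The prediction head emits one flat logit vector per graph...
flat_out = torch.randn(batch_size, num_tasks * num_classes)
# ...which is reshaped so each task gets its own row of class logits.
out = torch.reshape(flat_out, (-1, num_tasks, num_classes))
print(out.shape)  # torch.Size([16, 3, 2])
```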


@@ -213,6 +224,7 @@ def __init__(self,
num_layer: int = 3,
emb_dim: int = 64,
num_tasks: int = 1,
num_classes: int = 2,
graph_pooling: str = "attention",
dropout: int = 0,
JK: str = "concat",
@@ -221,12 +233,23 @@
self.gnn_type = gnn_type
self.num_layer = num_layer
self.emb_dim = emb_dim

self.num_tasks = num_tasks
self.num_classes = num_classes
if task == "classification":
self.output_dim = num_classes * num_tasks
self.criterion = SoftmaxCrossEntropy()._create_pytorch_loss()
elif task == "regression":
self.output_dim = num_tasks
self.criterion = F.mse_loss
elif task == "edge_pred":
self.output_dim = num_tasks
self.edge_pred_loss = EdgePredictionLoss()._create_pytorch_loss()

self.graph_pooling = graph_pooling
self.dropout = dropout
self.JK = JK
self.task = task

self.components = self.build_components()
self.model = self.build_model()
@@ -300,9 +323,9 @@ def build_components(self):

if self.JK == "concat":
head = torch.nn.Linear(mult * (self.num_layer + 1) * self.emb_dim,
self.output_dim)
else:
head = torch.nn.Linear(mult * self.emb_dim, self.output_dim)

components = {
'atom_type_embedding':
@@ -321,7 +344,8 @@
self.gnn = GNN(components['atom_type_embedding'],
components['chirality_embedding'], components['gconvs'],
components['batch_norms'], self.dropout, self.JK)
self.gnn_head = GNNHead(components['pool'], components['head'],
self.task, self.num_tasks, self.num_classes)
return components

def build_model(self):
@@ -335,36 +359,34 @@

if self.task == "edge_pred": # unsupervised task, does not need pred head
return self.gnn
elif self.task in ("regression", "classification"):
return torch.nn.Sequential(self.gnn, self.gnn_head)
else:
raise ValueError(f"Task {self.task} is not supported.")
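One design note on the chaining above: GNN.forward returns a (node_representation, data) tuple and GNNHead.forward takes that tuple as its single argument, which is what allows the two modules to be composed with torch.nn.Sequential. A toy illustration of the pattern (the classes here are hypothetical, not the PR's):

```python
import torch

class Encoder(torch.nn.Module):
    def forward(self, x):
        # Return the transformed tensor together with the original input.
        return (x * 2, x)

class Head(torch.nn.Module):
    def forward(self, data):
        # Sequential hands the previous module's tuple over unchanged.
        out, original = data
        return out + original

model = torch.nn.Sequential(Encoder(), Head())
print(model(torch.ones(3)))  # tensor([3., 3., 3.])
```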

    def loss_func(self, inputs, labels, weights):
        """
        The loss function executed in the training loop, which is based on the specified task.
        """
        if self.task == "edge_pred":
            # node_emb shape == [num_nodes x emb_dim]
            node_emb, inputs = self.model(inputs)
            loss = self.edge_pred_loss(node_emb, inputs)
        elif self.task == "regression":
            loss = self.regression_loss(inputs, labels)
        elif self.task == "classification":
            loss = self.classification_loss(inputs, labels)
        return (loss * weights).mean()

    def regression_loss(self, inputs, labels):
        out = self.model(inputs)
        reg_loss = self.criterion(out, labels)
        return reg_loss

    def classification_loss(self, inputs, labels):
        out = self.model(inputs)
        out = F.softmax(out, dim=2)
        class_loss = self.criterion(out, labels)
        return class_loss
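As a shape sanity check, a multitask classification batch flows through this loss roughly as follows; this plain-torch sketch (made-up sizes, not the DeepChem Loss API) mimics softmax cross entropy over the last axis:

```python
import torch
import torch.nn.functional as F

batch_size, num_tasks, num_classes = 4, 3, 2
logits = torch.randn(batch_size, num_tasks, num_classes)
# One-hot labels shaped (batch, num_tasks, num_classes), as produced by
# default_generator for classification tasks.
labels = F.one_hot(torch.randint(num_classes, (batch_size, num_tasks)),
                   num_classes).float()
# Cross entropy per task, averaged over batch and tasks.
loss = -(labels * F.log_softmax(logits, dim=2)).sum(dim=2).mean()
```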

def _prepare_batch(self, batch):
"""
@@ -397,6 +419,23 @@ def _prepare_batch(self, batch):

return inputs, labels, weights

    def default_generator(
            self,
            dataset: Dataset,
            epochs: int = 1,
            mode: str = 'fit',
            deterministic: bool = True,
            pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
        for epoch in range(epochs):
            for (X_b, y_b, w_b,
                 ids_b) in dataset.iterbatches(batch_size=self.batch_size,
                                               deterministic=deterministic,
                                               pad_batches=pad_batches):
                if self.task == 'classification' and y_b is not None:
                    # Flat (batch, num_tasks) labels become one-hot
                    # (batch, num_tasks, num_classes) arrays for the
                    # classification loss.
                    y_b = to_one_hot(y_b.flatten(), self.num_classes).reshape(
                        -1, self.num_tasks, self.num_classes)
                yield ([X_b], [y_b], [w_b])
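For instance, here is a minimal sketch (sizes made up) of what that one-hot transform does to a flat label matrix, using DeepChem's existing to_one_hot helper:

```python
import numpy as np
from deepchem.metrics import to_one_hot

num_tasks, num_classes = 3, 2
# Hypothetical binary labels for two molecules across three tasks.
y_b = np.array([[0., 1., 1.],
                [1., 0., 1.]])
y_hot = to_one_hot(y_b.flatten(), num_classes).reshape(-1, num_tasks, num_classes)
print(y_hot.shape)  # (2, 3, 2)
```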


def negative_edge_sampler(data: BatchGraphData):
"""
3 changes: 3 additions & 0 deletions docs/source/api_reference/models.rst
@@ -105,6 +105,9 @@ Losses
.. autoclass:: deepchem.models.losses.GroverPretrainLoss
   :members:

.. autoclass:: deepchem.models.losses.EdgePredictionLoss
   :members:

Optimizers
----------
