Supervised GNNModular tasks #3343

Merged: 19 commits, Apr 12, 2023
67 changes: 67 additions & 0 deletions deepchem/models/losses.py
@@ -909,6 +909,73 @@ def loss(atom_vocab_task_atom_pred: torch.Tensor,
return loss


class EdgePredictionLoss(Loss):
"""
EdgePredictionLoss is an unsupervised graph edge prediction loss function that calculates the loss based on the similarity between node embeddings for positive and negative edge pairs. This loss function is designed for graph neural networks and is particularly useful for pre-training tasks.

This loss function encourages the model to learn node embeddings that can effectively distinguish between true edges (positive samples) and false edges (negative samples) in the graph.

The loss is computed by comparing the similarity scores (dot product) of node embeddings for positive and negative edge pairs. The goal is to maximize the similarity for positive pairs and minimize it for negative pairs.
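
Concretely, with node embeddings h, the loss is BCEWithLogits(h_u · h_v, 1) for each positive pair (u, v) plus BCEWithLogits(h_u' · h_v', 0) for each sampled negative pair (u', v').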

To use this loss function, the input must be a BatchGraphData object transformed by the negative_edge_sampler. The loss function takes the node embeddings and the input graph data (with positive and negative edge pairs) as inputs and returns the edge prediction loss.

Examples
--------
>>> from deepchem.models.losses import EdgePredictionLoss
>>> from deepchem.feat.graph_data import BatchGraphData, GraphData
>>> from deepchem.models.torch_models.gnn import negative_edge_sampler
>>> import torch
>>> import numpy as np
>>> emb_dim = 8
>>> num_nodes_list, num_edge_list = [3, 4, 5], [2, 4, 5]
>>> num_node_features, num_edge_features = 32, 32
>>> edge_index_list = [
... np.array([[0, 1], [1, 2]]),
... np.array([[0, 1, 2, 3], [1, 2, 0, 2]]),
... np.array([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]]),
... ]
>>> graph_list = [
... GraphData(node_features=np.random.random_sample(
... (num_nodes_list[i], num_node_features)),
... edge_index=edge_index_list[i],
... edge_features=np.random.random_sample(
... (num_edge_list[i], num_edge_features)),
... node_pos_features=None) for i in range(len(num_edge_list))
... ]
>>> batched_graph = BatchGraphData(graph_list)
>>> batched_graph = batched_graph.numpy_to_torch()
>>> neg_sampled = negative_edge_sampler(batched_graph)
>>> embedding = np.random.random((sum(num_nodes_list), emb_dim))
>>> embedding = torch.from_numpy(embedding)
>>> loss_func = EdgePredictionLoss()._create_pytorch_loss()
>>> loss = loss_func(embedding, neg_sampled)

References
----------
.. [1] Hu, W. et al. Strategies for Pre-training Graph Neural Networks. Preprint at https://doi.org/10.48550/arXiv.1905.12265 (2020).
"""

def _create_pytorch_loss(self):
    import torch
    self.criterion = torch.nn.BCEWithLogitsLoss()

    def loss(node_emb, inputs):
        # edge_index stores each undirected edge in both directions, so
        # striding by 2 scores each positive edge exactly once
        positive_score = torch.sum(node_emb[inputs.edge_index[0, ::2]] *
                                   node_emb[inputs.edge_index[1, ::2]],
                                   dim=1)
        # negative pairs are produced by negative_edge_sampler
        negative_score = torch.sum(node_emb[inputs.negative_edge_index[0]] *
                                   node_emb[inputs.negative_edge_index[1]],
                                   dim=1)

        # push positive-pair scores toward 1 and negative-pair scores toward 0
        edge_pred_loss = self.criterion(
            positive_score,
            torch.ones_like(positive_score)) + self.criterion(
                negative_score, torch.zeros_like(negative_score))
        return edge_pred_loss

    return loss
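
For intuition, the two BCE-with-logits terms can be checked on toy similarity scores. This is a standalone sketch with made-up score values, not part of the class above:

    import torch

    criterion = torch.nn.BCEWithLogitsLoss()
    pos = torch.tensor([2.0, 1.5])   # dot products for true edges, should be high
    neg = torch.tensor([-1.0, 0.3])  # dot products for sampled non-edges, should be low
    loss = criterion(pos, torch.ones_like(pos)) + criterion(neg, torch.zeros_like(neg))
    # the loss shrinks as positive scores grow and negative scores fall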


def _make_tf_shapes_consistent(output, labels):
"""Try to make inputs have the same shape by adding dimensions of size 1."""
import tensorflow as tf
79 changes: 71 additions & 8 deletions deepchem/models/tests/test_gnn.py
@@ -27,16 +27,34 @@ def compare_weights(key, model1, model2):
model2.components[key].weight)).item()


def get_multitask_regression_dataset():
featurizer = SNAPFeaturizer()
dir = os.path.dirname(os.path.abspath(__file__))

input_file = os.path.join(dir, 'assets/multitask_regression.csv')
loader = dc.data.CSVLoader(tasks=['task0', 'task1', 'task2'],
feature_field="smiles",
featurizer=featurizer)
dataset = loader.create_dataset(input_file)
metric = dc.metrics.Metric(dc.metrics.mean_absolute_error,
mode="regression")
return dataset, metric


def get_multitask_classification_dataset():
    featurizer = SNAPFeaturizer()
    dir = os.path.dirname(os.path.abspath(__file__))

    input_file = os.path.join(dir, 'assets/multitask_example.csv')
    loader = dc.data.CSVLoader(tasks=['task0', 'task1', 'task2'],
                               feature_field="smiles",
                               featurizer=featurizer)
    dataset = loader.create_dataset(input_file)
    metric = dc.metrics.Metric(dc.metrics.roc_auc_score,
                               np.mean,
                               mode="classification")
    return dataset, metric


@pytest.mark.torch
@@ -54,3 +72,48 @@ def test_GNN_save_reload():
if hasattr(model.components[key], 'weight')
]
assert all(compare_weights(key, model, model2) for key in keys_with_weights)


@pytest.mark.torch
def test_GNN_edge_pred():
"""Tests the unsupervised edge prediction task"""
from deepchem.models.torch_models.gnn import GNNModular

dataset, _ = get_regression_dataset()
model = GNNModular(task="edge_pred")
loss1 = model.fit(dataset, nb_epoch=5)
loss2 = model.fit(dataset, nb_epoch=5)
assert loss2 < loss1


@pytest.mark.torch
def test_GNN_regression():
from deepchem.models.torch_models.gnn import GNNModular

dataset, metric = get_regression_dataset()
model = GNNModular(task="regression")
model.fit(dataset, nb_epoch=100)
scores = model.evaluate(dataset, [metric])
assert scores['mean_absolute_error'] < 0.2


@pytest.mark.torch
def test_GNN_multitask_regression():
from deepchem.models.torch_models.gnn import GNNModular

dataset, metric = get_multitask_regression_dataset()
model = GNNModular(task="regression", num_tasks=3)
model.fit(dataset, nb_epoch=100)
scores = model.evaluate(dataset, [metric])
assert scores['mean_absolute_error'] < 0.2


@pytest.mark.torch
def test_GNN_multitask_classification():
from deepchem.models.torch_models.gnn import GNNModular

dataset, metric = get_multitask_classification_dataset()
model = GNNModular(task="classification", num_tasks=3)
model.fit(dataset, nb_epoch=200)
scores = model.evaluate(dataset, [metric])
assert scores['mean-roc_auc_score'] >= 0.8
118 changes: 80 additions & 38 deletions deepchem/models/torch_models/gnn.py
Expand Up @@ -2,8 +2,12 @@
from torch_geometric.nn import GINEConv, global_add_pool, global_mean_pool, global_max_pool
from torch_geometric.nn.aggr import AttentionalAggregation, Set2Set
from torch.functional import F
from deepchem.data import Dataset
from deepchem.models.losses import SoftmaxCrossEntropy, EdgePredictionLoss
from deepchem.models.torch_models import ModularTorchModel
from deepchem.feat.graph_data import BatchGraphData
from typing import Iterable, List, Tuple
from deepchem.metrics import to_one_hot

num_atom_type = 120
num_chirality_tag = 3
@@ -52,10 +56,10 @@ class GNN(torch.nn.Module):
>>> smiles = ["C1=CC=CC=C1", "C1=CC=CC=C1C=O", "C1=CC=CC=C1C(=O)O"]
>>> features = featurizer.featurize(smiles)
>>> batched_graph = BatchGraphData(features).numpy_to_torch(device="cuda")
>>> modular = GNNModular(emb_dim=8, task="edge_pred")
>>> gnnmodel = modular.gnn
>>> print(gnnmodel(batched_graph)[0].shape)
torch.Size([23, 32])

"""

@@ -109,7 +113,7 @@ def forward(self, data: BatchGraphData):
h = F.dropout(F.relu(h), self.dropout, training=self.training)
h_list.append(h)

# Different implementations of jump_knowledge
if self.jump_knowledge == "concat":
node_representation = torch.cat(h_list, dim=1)
elif self.jump_knowledge == "last":
@@ -121,7 +125,7 @@
h_list = [h.unsqueeze_(0) for h in h_list]
node_representation = torch.sum(torch.cat(h_list, dim=0), dim=0)[0]

return (node_representation, data)


class GNNHead(torch.nn.Module):
@@ -142,25 +146,31 @@ class GNNHead(torch.nn.Module):
Number of classes for classification.
"""

def __init__(self, pool, head, task, num_tasks, num_classes):
super().__init__()
self.pool = pool
self.head = head
self.task = task
self.num_tasks = num_tasks
self.num_classes = num_classes

def forward(self, data):
"""
Forward pass for the GNN head module.

Parameters
----------
data: tuple
    A tuple containing the node representations and the input graph data.
    node_representation is a torch.Tensor produced by the GNN layers.
    input_batch is the original input BatchGraphData.
"""
node_representation, input_batch = data

pooled = self.pool(node_representation, input_batch.graph_index)
out = self.head(pooled)
if self.task == "classification":
    # unflatten logits from (batch, num_tasks * num_classes)
    # to (batch, num_tasks, num_classes)
    out = torch.reshape(out, (-1, self.num_tasks, self.num_classes))
return out


@@ -215,6 +225,7 @@ def __init__(self,
num_layer: int = 3,
emb_dim: int = 64,
num_tasks: int = 1,
num_classes: int = 2,
graph_pooling: str = "attention",
dropout: int = 0,
jump_knowledge: str = "concat",
@@ -223,12 +234,23 @@ def __init__(self,
self.gnn_type = gnn_type
self.num_layer = num_layer
self.emb_dim = emb_dim

self.num_tasks = num_tasks
self.num_classes = num_classes
if task == "classification":
self.output_dim = num_classes * num_tasks
self.criterion = SoftmaxCrossEntropy()._create_pytorch_loss()
elif task == "regression":
self.output_dim = num_tasks
self.criterion = F.mse_loss
elif task == "edge_pred":
self.output_dim = num_tasks
self.edge_pred_loss = EdgePredictionLoss()._create_pytorch_loss()
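# Note: for classification the head emits num_tasks * num_classes logits,
# which GNNHead later reshapes to (batch, num_tasks, num_classes); for
# regression it emits one value per task.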

self.graph_pooling = graph_pooling
self.dropout = dropout
self.jump_knowledge = jump_knowledge
self.task = task

self.components = self.build_components()
self.model = self.build_model()
@@ -302,9 +324,9 @@ def build_components(self):

if self.jump_knowledge == "concat":
head = torch.nn.Linear(mult * (self.num_layer + 1) * self.emb_dim,
self.output_dim)
else:
head = torch.nn.Linear(mult * self.emb_dim, self.output_dim)

components = {
'atom_type_embedding':
@@ -324,7 +346,8 @@ def build_components(self):
components['chirality_embedding'], components['gconvs'],
components['batch_norms'], self.dropout,
self.jump_knowledge)
self.gnn_head = GNNHead(components['pool'], components['head'],
                        self.task, self.num_tasks, self.num_classes)
return components

def build_model(self):
Expand All @@ -338,36 +361,34 @@ def build_model(self):

if self.task == "edge_pred": # unsupervised task, does not need pred head
return self.gnn
elif self.task in ("regression", "classification"):
return torch.nn.Sequential(self.gnn, self.gnn_head)
else:
raise ValueError(f"Task {self.task} is not supported.")
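# For example (cf. the tests earlier in this diff):
#   GNNModular(task="edge_pred")  -> bare GNN trained with EdgePredictionLoss
#   GNNModular(task="regression") -> torch.nn.Sequential(GNN, GNNHead)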

def loss_func(self, inputs, labels, weights):
"""
The loss function executed in the training loop, which is based on the specified task.
"""
if self.task == "edge_pred":
node_emb, inputs = self.model(inputs)
loss = self.edge_pred_loss(node_emb, inputs)
elif self.task == "regression":
loss = self.regression_loss(inputs, labels)
elif self.task == "classification":
loss = self.classification_loss(inputs, labels)
return (loss * weights).mean()

def regression_loss(self, inputs, labels):
out = self.model(inputs)
reg_loss = self.criterion(out, labels)
return reg_loss

def classification_loss(self, inputs, labels):
out = self.model(inputs)
out = F.softmax(out, dim=2)
class_loss = self.criterion(out, labels)
return class_loss

def _prepare_batch(self, batch):
"""
@@ -400,6 +421,27 @@ def _prepare_batch(self, batch):

return inputs, labels, weights

def default_generator(
self,
dataset: Dataset,
epochs: int = 1,
mode: str = 'fit',
deterministic: bool = True,
pad_batches: bool = True) -> Iterable[Tuple[List, List, List]]:
"""
This default generator is modified from the default generator in dc.models.tensorgraph.tensor_graph.py to support multitask classification. If the task is classification, the labels y_b are converted to a one-hot encoding and reshaped to (batch_size, num_tasks, num_classes); see the sketch after this method.
"""

for epoch in range(epochs):
for (X_b, y_b, w_b,
ids_b) in dataset.iterbatches(batch_size=self.batch_size,
deterministic=deterministic,
pad_batches=pad_batches):
if self.task == 'classification' and y_b is not None:
y_b = to_one_hot(y_b.flatten(), self.num_classes).reshape(
-1, self.num_tasks, self.num_classes)
yield ([X_b], [y_b], [w_b])
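
As a quick sketch of that label transform (toy binary labels, shapes only; this mirrors the to_one_hot call above rather than DeepChem's full data path):

    import numpy as np
    from deepchem.metrics import to_one_hot

    y_b = np.array([[0., 1., 1.],
                    [1., 0., 1.]])  # (batch=2, num_tasks=3) binary labels
    y_b = to_one_hot(y_b.flatten(), 2).reshape(-1, 3, 2)
    # y_b.shape == (2, 3, 2): one (num_classes,) one-hot row per task per sample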


def negative_edge_sampler(data: BatchGraphData):
"""
3 changes: 3 additions & 0 deletions docs/source/api_reference/models.rst
@@ -105,6 +105,9 @@ Losses
.. autoclass:: deepchem.models.losses.GroverPretrainLoss
:members:

.. autoclass:: deepchem.models.losses.EdgePredictionLoss
:members:

Optimizers
----------
