Added infograph model finetuning support #3491

Merged · 7 commits · Aug 3, 2023

Changes from all commits
32 changes: 32 additions & 0 deletions deepchem/models/tests/test_infograph.py
@@ -307,3 +307,35 @@ def test_infographstar_fit_restore():
    model2.fit(dataset, nb_epoch=1, restore=True)
    prediction = model2.predict_on_batch(dataset.X).reshape(-1, 1)
    assert np.allclose(dataset.y, np.round(prediction))


@pytest.mark.torch
def test_infograph_pretrain_finetune(tmpdir):
    from deepchem.models.torch_models.infograph import InfoGraphModel
    import torch
    torch.manual_seed(123)
    np.random.seed(123)

    dataset, _ = get_regression_dataset()
    num_feat = 30
    edge_dim = 11

    pretrain_model = InfoGraphModel(num_feat,
                                    edge_dim,
                                    num_gc_layers=1,
                                    model_dir=tmpdir,
                                    device=torch.device('cpu'))
    pretraining_loss = pretrain_model.fit(dataset, nb_epoch=1)
    assert pretraining_loss
    pretrain_model.save_checkpoint()

    finetune_model = InfoGraphModel(num_feat,
                                    edge_dim,
                                    num_gc_layers=1,
                                    task='regression',
                                    n_tasks=1,
                                    model_dir=tmpdir,
                                    device=torch.device('cpu'))
    finetune_model.restore(components=['encoder'])
    finetuning_loss = finetune_model.fit(dataset, nb_epoch=1)
    assert finetuning_loss
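What makes the hand-off in this test work is the selective restore: save_checkpoint() writes one state dict per component (plus the assembled 'model', as the restore() loop in this diff assumes), and restore() copies over only the entries whose names also exist on the current model's components. A minimal sketch of the equivalent manual transfer, reusing finetune_model from the test above; the checkpoint filename is hypothetical (restore() normally resolves it via get_checkpoints()):

import torch

data = torch.load('checkpoint1.pt', map_location='cpu')  # hypothetical filename
# Only the encoder is shared between the pretraining and regression models,
# so only its weights transfer; fc1/fc2 keep their fresh initialization.
finetune_model.components['encoder'].load_state_dict(data['encoder'])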
180 changes: 142 additions & 38 deletions deepchem/models/torch_models/infograph.py
@@ -2,7 +2,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import GRU, Linear, ReLU, Sequential
from typing import Iterable, List, Tuple, Optional, Dict
from deepchem.metrics import to_one_hot

import deepchem as dc
@@ -234,9 +234,63 @@ def forward(self, data):
        return g_enc, l_enc


class InfoGraphFinetune(nn.Module):
"""The finetuning module for InfoGraph model

Parameters
----------
encoder: nn.Module
An encoder to encode input graph data
fc1: nn.Module
A fully connected layer
fc2: nn.Module
A fully connected layer

Example
-------
>>> from deepchem.models.torch_models.infograph import InfoGraphModel
>>> from deepchem.feat.molecule_featurizers import MolGraphConvFeaturizer
>>> from deepchem.feat.graph_data import BatchGraphData
>>> num_feat = 30
>>> num_edge = 11
>>> infographmodular = InfoGraphModel(num_feat, num_edge, num_gc_layers=1, task='regression', n_tasks=1)
>>> smiles = ['C1=CC=CC=C1', 'C1=CC=CC=C1C2=CC=CC=C2']
>>> featurizer = MolGraphConvFeaturizer(use_edges=True)
>>> graphs = BatchGraphData(featurizer.featurize(smiles))
>>> graphs = graphs.numpy_to_torch(infographmodular.device)
>>> model = infographmodular.model
>>> predictions = model(graphs)

Reference
---------
.. Sun, F.-Y, et. al, "InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization".
"""
InfoGraph is a graph convolutional model for unsupervised graph-level representation learning. The model aims to maximize the mutual information between the representations of entire graphs and the representations of substructures of different granularity.

    def __init__(self, encoder, fc1, fc2, init_emb=False):
        super().__init__()
        self.encoder = encoder
        self.fc1 = fc1
        self.fc2 = fc2
        if init_emb:
            self.init_emb()

    def init_emb(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def forward(self, data):
        y, _ = self.encoder(data)
        return self.fc2(F.relu(self.fc1(y)))


class InfoGraphModel(ModularTorchModel):
    """InfoGraphModel

    InfoGraphModel is a model which learns graph-level representations via unsupervised learning. To this end,
    the model aims to maximize the mutual information between the representations of entire graphs and the
    representations of substructures of different granularity (e.g. nodes, edges, triangles).

    The unsupervised training of InfoGraph involves two encoders: one that encodes the entire graph and another
    that encodes substructures of different sizes. The mutual information between the two encoder outputs is
    maximized using a contrastive loss function. The model randomly samples pairs of graphs and substructures,
    and then maximizes their mutual information by minimizing their distance in a learned embedding space.
@@ -273,16 +327,21 @@ class InfoGraphModel(ModularTorchModel):
    >>> from deepchem.feat import MolGraphConvFeaturizer
    >>> from deepchem.data import NumpyDataset
    >>> import torch
    >>> import tempfile
    >>> tempdir = tempfile.TemporaryDirectory()
    >>> smiles = ["C1CCC1", "C1=CC=CN=C1"]
    >>> featurizer = MolGraphConvFeaturizer(use_edges=True)
    >>> X = featurizer.featurize(smiles)
    >>> y = torch.randint(0, 2, size=(2, 1)).float()
    >>> w = torch.ones(size=(2, 1)).float()
    >>> dataset = NumpyDataset(X, y, w)
    >>> num_feat, edge_dim = 30, 11  # feature dims produced by MolGraphConvFeaturizer
    >>> pretrain_model = InfoGraphModel(num_feat, edge_dim, num_gc_layers=1, task='pretraining', model_dir=tempdir.name)
    >>> pretraining_loss = pretrain_model.fit(dataset, nb_epoch=1)
    >>> pretrain_model.save_checkpoint()
    >>> finetune_model = InfoGraphModel(num_feat, edge_dim, num_gc_layers=1, task='regression', n_tasks=1, model_dir=tempdir.name)
    >>> finetune_model.restore(components=['encoder'])
    >>> finetuning_loss = finetune_model.fit(dataset)
    """

    def __init__(self,
@@ -293,7 +352,11 @@ def __init__(self,
                 gamma=.1,
                 measure='JSD',
                 average_loss=True,
                 task='pretraining',
                 n_tasks: Optional[int] = None,
                 **kwargs):
        if task == 'regression':
            assert n_tasks, 'Number of prediction tasks required for building regression model'
        self.num_features = num_features
        self.embedding_dim = embedding_dim * num_gc_layers
        self.num_gc_layers = num_gc_layers
@@ -303,6 +366,8 @@ def __init__(self,
        self.average_loss = average_loss
        self.localloss = LocalMutualInformationLoss()._create_pytorch_loss(
            measure, average_loss)
        self.task = task
        self.n_tasks = n_tasks
        self.components = self.build_components()
        self.model = self.build_model()
        super().__init__(self.model, self.components, **kwargs)
@@ -322,41 +387,62 @@ def build_components(self) -> dict:
        global_d: MultilayerPerceptron, global discriminator

        prior_d: MultilayerPerceptron, prior discriminator
        fc1: nn.Linear, dense layer used during finetuning
        fc2: nn.Linear, dense layer used during finetuning
        """
        components: Dict[str, nn.Module] = {}
        if self.task == 'pretraining':
            components['encoder'] = GINEncoder(self.num_features,
                                               self.embedding_dim,
                                               self.num_gc_layers)
            components['local_d'] = MultilayerPerceptron(self.embedding_dim,
                                                         self.embedding_dim,
                                                         (self.embedding_dim,),
                                                         skip_connection=True)
            components['global_d'] = MultilayerPerceptron(self.embedding_dim,
                                                          self.embedding_dim,
                                                          (self.embedding_dim,),
                                                          skip_connection=True)
            components['prior_d'] = MultilayerPerceptron(
                self.embedding_dim,
                1, (self.embedding_dim,),
                activation_fn='sigmoid')
        elif self.task == 'regression':
            components['encoder'] = GINEncoder(self.num_features,
                                               self.embedding_dim,
                                               self.num_gc_layers)
            components['fc1'] = torch.nn.Linear(self.embedding_dim,
                                                self.embedding_dim)
            # n_tasks is Optional[int] while argument 2 of nn.Linear has to be of type int
            components['fc2'] = torch.nn.Linear(self.embedding_dim,
                                                self.n_tasks)  # type: ignore
        return components

    def build_model(self) -> nn.Module:
        if self.task == 'pretraining':
            model = InfoGraph(**self.components)
        elif self.task == 'regression':
            model = InfoGraphFinetune(**self.components)  # type: ignore
        return model

    def loss_func(self, inputs, labels, weights):
        if self.task == 'pretraining':
            y, M = self.components['encoder'](inputs)
            g_enc = self.components['global_d'](y)
            l_enc = self.components['local_d'](M)
            local_global_loss = self.localloss(l_enc, g_enc, inputs.graph_index)
            if self.prior:
                prior = torch.rand_like(y)
                term_a = torch.log(self.components['prior_d'](prior)).mean()
                term_b = torch.log(1.0 - self.components['prior_d'](y)).mean()
                prior = -(term_a + term_b) * self.gamma
            else:
                prior = 0
            return local_global_loss + prior
        elif self.task == 'regression':
            loss_fn = nn.MSELoss()
            y = self.model(inputs)
            return loss_fn(y, labels)
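Note on the loss above: written out as math (our reconstruction from the code, not notation taken from the paper), with $D$ the prior_d discriminator and the first term the measure-dependent local–global mutual-information loss, the pretraining objective is

$$\mathcal{L}_{\text{pretrain}} = \mathcal{L}_{\text{MI}}(l_{\text{enc}}, g_{\text{enc}}) - \gamma\left(\mathbb{E}_{p \sim U(0,1)}\left[\log D(p)\right] + \mathbb{E}_{y}\left[\log\left(1 - D(y)\right)\right]\right)$$

i.e. the prior term pushes the distribution of graph embeddings $y$ toward a uniform prior, GAN-style. For task='regression' the loss is plain MSE between the fine-tuning head's output and the labels.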

    def _prepare_batch(self, batch):
        """
@@ -373,6 +459,24 @@ def _prepare_batch(self, batch):

        return inputs, labels, weights

    def restore(  # type: ignore
            self,
            components: Optional[List[str]] = None,
            checkpoint: Optional[str] = None,
            model_dir: Optional[str] = None,
            map_location: Optional[torch.device] = None) -> None:
        if checkpoint is None:
            checkpoints = sorted(self.get_checkpoints(model_dir))
            if len(checkpoints) == 0:
                raise ValueError('No checkpoint found')
            checkpoint = checkpoints[0]
        data = torch.load(checkpoint, map_location=map_location)
        # The checkpoint holds one state dict per component plus the assembled
        # 'model'; only entries whose names also exist on this model's
        # components (e.g. just 'encoder' when fine-tuning) are loaded.
        for name, state_dict in data.items():
            if name != 'model' and name in self.components.keys():
                self.components[name].load_state_dict(state_dict)

        self.build_model()


class InfoGraphStar(torch.nn.Module):
"""
Expand Down
2 changes: 1 addition & 1 deletion — deepchem/models/torch_models/modular.py

@@ -248,7 +248,6 @@ def fit_generator(self,
                last_avg_loss = avg_loss
                avg_loss = 0.0
                averaged_batches = 0
            if checkpoint_interval > 0 and current_step % checkpoint_interval == checkpoint_interval - 1:
                self.save_checkpoint(max_checkpoints_to_keep)
            for c in callbacks:
@@ -386,6 +385,7 @@ def restore(  # type: ignore
        model_dir: Optional[str]
            The path to the model directory. If None, the model directory used to initialize the model will be used.
        """
        logger.info('Restoring model')
[Member comment] These logger changes are also reflected in your other PR. I'm fine merging them in as part of this PR since they are small.

        if checkpoint is None:
            checkpoints = sorted(self.get_checkpoints(model_dir))
            if len(checkpoints) == 0:
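A small readability note on the fit_generator change above: the condition current_step % checkpoint_interval == checkpoint_interval - 1 fires once every checkpoint_interval steps (at steps N-1, 2N-1, ...), and checkpoint_interval=0 disables step-based checkpointing entirely. Assuming the fit signature inherited from TorchModel, the cadence is set per call:

loss = pretrain_model.fit(dataset, nb_epoch=10, checkpoint_interval=100)  # checkpoint every 100 steps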
4 changes: 4 additions & 0 deletions — deepchem/models/torch_models/torch_model.py

@@ -8,6 +8,7 @@
import time
import logging
import os
import datetime

from deepchem.data import Dataset, NumpyDataset
from deepchem.metrics import Metric
@@ -971,6 +972,8 @@ def default_generator(
        ([inputs], [outputs], [weights])
        """
        for epoch in range(epochs):
            logger.info("Starting training for epoch %d at %s" %
                        (epoch, datetime.datetime.now().ctime()))
            for (X_b, y_b, w_b,
                 ids_b) in dataset.iterbatches(batch_size=self.batch_size,
                                               deterministic=deterministic,
@@ -1054,6 +1057,7 @@ def restore(self,
            Directory to restore checkpoint from. If None, use self.model_dir. If
            checkpoint is not None, this is ignored.
        """
        logger.info('Restoring model')
        self._ensure_built()
        if checkpoint is None:
            checkpoints = sorted(self.get_checkpoints(model_dir))
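With the new logger.info calls in place, the epoch and restore messages only appear when INFO logging is enabled. A minimal sketch of how a user would turn them on, assuming DeepChem's module loggers hang off the standard logging hierarchy:

import logging

logging.basicConfig()  # attach a default stream handler if none is configured
logging.getLogger('deepchem').setLevel(logging.INFO)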