Fix classification in infomax3d #3696

Merged · 3 commits · Dec 7, 2023
39 changes: 35 additions & 4 deletions deepchem/models/tests/test_gnn3d.py
@@ -112,7 +112,8 @@ def testInfoMax3DModular():
readout_aggregators=['sum', 'mean'],
scalers=['identity'],
device=torch.device('cpu'),
task='pretraining')
task='pretraining',
learning_rate=0.00001)

loss1 = model.fit(data, nb_epoch=1)
loss2 = model.fit(data, nb_epoch=9)
@@ -184,10 +185,40 @@ def testInfoMax3DModularClassification():
scalers=['identity'],
task='classification',
n_tasks=1,
n_classes=1,
n_classes=2,
device=torch.device('cpu'))

model.fit(data, nb_epoch=100)
model.fit(data, nb_epoch=10)
scores = model.evaluate(data, [metric])
# FIXME We need to improve finetuning score
assert scores['mean-roc_auc_score'] > 0.7


@pytest.mark.torch
def test_infomax3d_load_from_pretrained(tmpdir):
import torch
from deepchem.models.torch_models.gnn3d import InfoMax3DModular
pretrain_model = InfoMax3DModular(hidden_dim=64,
target_dim=10,
device=torch.device('cpu'),
task='pretraining',
model_dir=tmpdir)
pretrain_model._ensure_built()
pretrain_model.save_checkpoint()
pretrain_model_state_dict = pretrain_model.model.state_dict()

finetune_model = InfoMax3DModular(hidden_dim=64,
target_dim=10,
device=torch.device('cpu'),
task='classification',
n_classes=2,
n_tasks=1)
finetune_model_old_state_dict = finetune_model.model.state_dict()
# Finetune model weights should not match before loading from pretrained model
for key, value in pretrain_model_state_dict.items():
assert not torch.allclose(value, finetune_model_old_state_dict[key])
finetune_model.load_from_pretrained(pretrain_model, components=['model2d'])
finetune_model_new_state_dict = finetune_model.model.state_dict()

# Finetune model weights should match after loading from pretrained model
for key, value in pretrain_model_state_dict.items():
assert torch.allclose(value, finetune_model_new_state_dict[key])
36 changes: 23 additions & 13 deletions deepchem/models/torch_models/gnn3d.py
@@ -495,7 +495,16 @@ def __init__(self,
self.criterion = NTXentMultiplePositives()._create_pytorch_loss()
self.components = self.build_components()
self.model = self.build_model()
super().__init__(self.model, self.components, **kwargs)
if self.task == 'regression':
output_types = ['prediction']
elif self.task == 'classification':
output_types = ['prediction', 'logits']
else:
output_types = None
super().__init__(self.model,
self.components,
output_types=output_types,
**kwargs)
for module_name, module in self.components.items():
self.components[module_name] = module.to(self.device)
self.model = self.model.to(self.device)
@@ -525,6 +534,9 @@ def build_components(self):
dropout=self.dropout,
posttrans_layers=self.posttrans_layers,
pretrans_layers=self.pretrans_layers,
task=self.task,
n_tasks=self.n_tasks,
n_classes=self.n_classes,
**self.kwargs)
if self.task == 'pretraining':
return {
@@ -559,16 +571,9 @@ def build_model(self):
PNA
The 2D PNA model component.
"""
if self.task == 'pretraining':
# FIXME Pretrain uses both model2d and model3d but the super class
# can't handle two models for contrastive learning, hence we pass only model2d
return self.components['model2d']
elif self.task in ['regression', 'classification']:
if self.task == 'regression':
head = nn.Linear(self.target_dim, self.n_tasks)
elif self.task == 'classification':
head = nn.Linear(self.target_dim, self.n_tasks)
return nn.Sequential(self.components['model2d'], head)
# FIXME The pretraining task uses both model2d and model3d, but the super class
# can't handle two models for contrastive learning, hence we pass only model2d
return self.components['model2d']

def loss_func(self, inputs, labels, weights):
"""
@@ -596,8 +601,13 @@ def loss_func(self, inputs, labels, weights):
preds = self.model(inputs)
loss = F.mse_loss(preds, labels)
elif self.task == 'classification':
preds = self.model(inputs)
loss = F.binary_cross_entropy_with_logits(preds, labels)
proba, logits = self.model(inputs)
# torch's one-hot encoding works with integer data types.
# We convert labels to integers, one-hot encode them, and convert back to float
# to make them suitable for the loss function.
labels = F.one_hot(labels.squeeze().type(torch.int64)).type(
torch.float32)
loss = F.binary_cross_entropy_with_logits(logits, labels)
return loss

def _prepare_batch(self, batch):
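The classification branch of loss_func above, together with the n_classes change from 1 to 2 in the test, comes down to the label encoding: the head emits n_tasks * n_classes logits and the targets are one-hot vectors with n_classes columns, so a binary task needs two class slots. A standalone sketch of that path, where the (batch, 1) label shape and the batch size of 4 are assumptions for illustration rather than taken from the PR:

import torch
import torch.nn.functional as F

labels = torch.tensor([[0.], [1.], [1.], [0.]])  # assumed float labels of shape (batch, 1)
logits = torch.randn(4, 2)  # head output for n_tasks=1, n_classes=2
# Squeeze to (batch,), cast to int64 for one_hot, then back to float for the loss,
# mirroring the conversion in loss_func above.
targets = F.one_hot(labels.squeeze().to(torch.int64), num_classes=2).to(torch.float32)
loss = F.binary_cross_entropy_with_logits(logits, targets)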
3 changes: 2 additions & 1 deletion deepchem/models/torch_models/modular.py
@@ -313,7 +313,8 @@ def load_from_pretrained( # type: ignore
if source_model is not None:
for name, module in source_model.components.items():
if components is None or name in components:
self.components[name].load_state_dict(module.state_dict())
self.components[name].load_state_dict(module.state_dict(),
strict=False)
self.build_model()

elif source_model is None:
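Passing strict=False lets a pretraining checkpoint populate a finetuning component whose state dict carries extra keys for the new task head. A simplified illustration of the behaviour, using a toy two-layer Sequential that stands in for (but is not) the real PNA component:

import torch.nn as nn

pretrained = nn.Sequential(nn.Linear(64, 10))                   # pretraining: backbone only
finetune = nn.Sequential(nn.Linear(64, 10), nn.Linear(10, 2))   # finetuning: backbone + class head
# With strict=True this would fail with missing keys for the head ('1.weight', '1.bias');
# with strict=False the shared backbone weights are copied and the head keeps its initialization.
finetune.load_state_dict(pretrained.state_dict(), strict=False)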
32 changes: 30 additions & 2 deletions deepchem/models/torch_models/pna_gnn.py
@@ -5,6 +5,7 @@
import dgl
import torch
from torch import nn
import torch.nn.functional as F

from deepchem.feat.molecule_featurizers.conformer_featurizer import (
full_atom_feature_dims,
@@ -588,6 +589,7 @@ class PNA(nn.Module):
def __init__(self,
hidden_dim: int,
target_dim: int,
task: str,
aggregators: List[str] = ['mean'],
scalers: List[str] = ['identity'],
readout_aggregators: List[str] = ['mean'],
@@ -602,6 +604,8 @@ def __init__(self,
dropout: float = 0.0,
posttrans_layers: int = 1,
pretrans_layers: int = 1,
n_tasks: int = 1,
n_classes: int = 2,
**kwargs):
super(PNA, self).__init__()
self.node_gnn = PNAGNN(hidden_dim=hidden_dim,
@@ -619,17 +623,41 @@ def __init__(self,
if readout_hidden_dim == 1:
readout_hidden_dim = hidden_dim
self.readout_aggregators = readout_aggregators
self.output = MultilayerPerceptron(
output = MultilayerPerceptron(
d_input=hidden_dim * len(self.readout_aggregators),
d_hidden=(readout_hidden_dim,) * readout_layers,
batch_norm=False,
d_output=target_dim)

self.n_classes = n_classes
self.task, self.n_tasks = task, n_tasks
if self.task == 'regression':
head = nn.Linear(target_dim, n_tasks)
self.output = nn.Sequential(output, head)
elif self.task == 'classification':
# The model predicts unnormalized probabilities for each class and task
head = nn.Linear(target_dim, n_tasks * n_classes)
self.output = nn.Sequential(output, head)
else:
self.output = nn.Sequential(output)

def forward(self, graph: dgl.DGLGraph):
graph = self.node_gnn(graph)
readouts_to_cat = [
dgl.readout_nodes(graph, 'feat', op=aggr)
for aggr in self.readout_aggregators
]
readout = torch.cat(readouts_to_cat, dim=-1)
return self.output(readout)
outputs = self.output(readout)
if self.task == 'classification':
if self.n_tasks == 1:
softmax_dim = 1
logits = outputs.view(-1, self.n_classes)
else:
softmax_dim = 2
logits = outputs.view(-1, self.n_tasks, self.n_classes)
proba = F.softmax(logits, dim=softmax_dim)
return proba, logits
else:
return outputs
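A shape sketch of the classification reshaping in forward(), with an arbitrary batch size and task count chosen for illustration:

import torch
import torch.nn.functional as F

batch, n_tasks, n_classes = 4, 3, 2
flat = torch.randn(batch, n_tasks * n_classes)  # raw output of the linear head
logits = flat.view(-1, n_tasks, n_classes)      # (4, 3, 2) in the multi-task case
proba = F.softmax(logits, dim=2)                # normalize over classes, per task
assert torch.allclose(proba.sum(dim=2), torch.ones(batch, n_tasks))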