In [1]:
import torch

from torcheval.metrics import R2Score, BinaryAccuracy, BinaryAUROC

from molsetrep.utils.trainer import Trainer
from molsetrep.utils.datasets import molnet_loader
from molsetrep.utils.converters import molnet_to_pyg
from molsetrep.utils.root_mean_squared_error import RootMeanSquaredError
from molsetrep.models import GNNSetRepClassifier, GNNSetRepRegressor, GNNSetRepClassifierSubstruct

Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


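## Classification (MoleculeNet `bbbp`)

Binary classification evaluated with BinaryAccuracy and BinaryAUROC; the trainer monitors validation BinaryAUROC (higher is better).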

In [2]:
# Train a set-representation GNN classifier with substructure inputs on the
# MoleculeNet "bbbp" dataset, then evaluate on the test split.

SEED = 42
LEARNING_RATE = 1e-4
# Set True to decay the learning rate with ExponentialLR(gamma=0.9) during training.
USE_LR_SCHEDULER = False

# Seed torch's global RNG so model weight initialisation and torch-based
# sampling/shuffling repeat across Restart & Run All.
# NOTE(review): RNGs of other libraries possibly used inside molnet_loader /
# molnet_to_pyg (e.g. numpy for splitting) are not seeded here — confirm.
torch.manual_seed(SEED)

train, valid, test = molnet_loader("bbbp")
train_loader, valid_loader, test_loader = molnet_to_pyg(
    train,
    valid,
    test,
    label_type=torch.long,  # integer class labels, as NLLLoss requires
    imbalanced_sampler=True,
    secfp=True,
    index_graphs=False,
)

# Input feature sizes are read from the first training graph.
num_node_features = train_loader.dataset[0].num_node_features
num_edge_features = train_loader.dataset[0].num_edge_features
# NOTE(review): the positional hyperparameters (256, 1, 32, 16) are defined by the
# GNNSetRepClassifierSubstruct signature — confirm their meaning there.
model = GNNSetRepClassifierSubstruct(num_node_features, 256, 1, num_edge_features, 32, 16)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# NLLLoss expects log-probabilities; presumably the model ends in log_softmax — confirm.
criterion = torch.nn.NLLLoss()

# Only hand a scheduler to the Trainer when one is actually wanted, so no
# unused scheduler object is left attached to the optimizer.
trainer_kwargs = {}
if USE_LR_SCHEDULER:
    trainer_kwargs["scheduler"] = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

# Three separate metric lists (presumably train / valid / test) so metric state is
# not shared between splits. monitor_metric indexes into these lists:
# 1 -> BinaryAUROC, for which higher is better.
trainer = Trainer(
    model,
    optimizer,
    criterion,
    50,  # presumably the number of training epochs
    [BinaryAccuracy(), BinaryAUROC()],
    [BinaryAccuracy(), BinaryAUROC()],
    [BinaryAccuracy(), BinaryAUROC()],
    monitor_metric=1,
    monitor_lower_is_better=False,
    **trainer_kwargs,
)

trainer.train(train_loader, valid_loader)
trainer.test(test_loader)

*  Epoch 1: Train loss: 0.649 (BinaryAccuracy: 0.638, BinaryAUROC: 0.639)  Valid loss: 0.626 (BinaryAccuracy: 0.651, BinaryAUROC: 0.665)
*  Epoch 2: Train loss: 0.626 (BinaryAccuracy: 0.654, BinaryAUROC: 0.654)  Valid loss: 0.642 (BinaryAccuracy: 0.677, BinaryAUROC: 0.684)
*  Epoch 3: Train loss: 0.614 (BinaryAccuracy: 0.689, BinaryAUROC: 0.688)  Valid loss: 0.611 (BinaryAccuracy: 0.688, BinaryAUROC: 0.695)
|  Epoch 4: Train loss: 0.594 (BinaryAccuracy: 0.686, BinaryAUROC: 0.689)  Valid loss: 0.605 (BinaryAccuracy: 0.677, BinaryAUROC: 0.68)
*  Epoch 5: Train loss: 0.575 (BinaryAccuracy: 0.719, BinaryAUROC: 0.719)  Valid loss: 0.583 (BinaryAccuracy: 0.688, BinaryAUROC: 0.696)
|  Epoch 6: Train loss: 0.558 (BinaryAccuracy: 0.721, BinaryAUROC: 0.721)  Valid loss: 0.571 (BinaryAccuracy: 0.682, BinaryAUROC: 0.683)
|  Epoch 7: Train loss: 0.547 (BinaryAccuracy: 0.742, BinaryAUROC: 0.742)  Valid loss: 0.563 (BinaryAccuracy: 0.677, BinaryAUROC: 0.683)
|  Epoch 8: Train loss: 0.546 (BinaryAccur

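## Regression (MoleculeNet `lipo`)

Regression evaluated with R² and RMSE; the trainer monitors validation RMSE (lower is better).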

In [None]:
# Train a set-representation GNN regressor on the MoleculeNet "lipo" dataset,
# then evaluate on the test split.

SEED = 42
LEARNING_RATE = 1e-5

# Seed torch's global RNG so model weight initialisation and torch-based
# shuffling repeat across Restart & Run All.
# NOTE(review): RNGs of other libraries possibly used inside molnet_loader /
# molnet_to_pyg (e.g. numpy for splitting) are not seeded here — confirm.
torch.manual_seed(SEED)

train, valid, test = molnet_loader("lipo")
# Float targets, as MSELoss requires.
train_loader, valid_loader, test_loader = molnet_to_pyg(train, valid, test, label_type=torch.float)

# Input feature sizes are read from the first training graph.
num_node_features = train_loader.dataset[0].num_node_features
num_edge_features = train_loader.dataset[0].num_edge_features
# NOTE(review): the positional hyperparameters (256, 2, 100, 16) are defined by the
# GNNSetRepRegressor signature — confirm their meaning there.
model = GNNSetRepRegressor(num_node_features, 256, 2, num_edge_features, 100, 16)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = torch.nn.MSELoss()

# Three separate metric lists (presumably train / valid / test) so metric state is
# not shared between splits. monitor_metric=1 selects RootMeanSquaredError, so the
# direction is stated explicitly (lower is better) instead of relying on the
# Trainer's default.
trainer = Trainer(
    model,
    optimizer,
    criterion,
    200,  # presumably the number of training epochs
    [R2Score(), RootMeanSquaredError()],
    [R2Score(), RootMeanSquaredError()],
    [R2Score(), RootMeanSquaredError()],
    monitor_metric=1,
    monitor_lower_is_better=True,
)

trainer.train(train_loader, valid_loader)
trainer.test(test_loader)