In [1]:
import torch

from torchmetrics.classification import Accuracy, AUROC
from torchmetrics.regression import R2Score, MeanSquaredError

from torch_geometric.nn import GAT

from molsetrep.utils.trainer import Trainer
from molsetrep.utils.datasets import molnet_loader
from molsetrep.utils.converters import molnet_to_pyg, smiles_to_nx, nx_to_pyg
from molsetrep.utils.root_mean_squared_error import RootMeanSquaredError
from molsetrep.models import GNNSetRepClassifier, GNNSetRepRegressor, GNNSetRepClassifierSubstruct, GNNRegressor, GNNClassifier
from molsetrep.explain import RegressionExplainer

Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'pytorch_lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
nx_to_pyg(smiles_to_nx("C.CNC"))

Data(edge_index=[2, 4], atomic_num=[4], charge=[4], aromatic=[4], is_in_ring=[4], hydrogen_count=[4], hybridization_sp=[4], hybridization_sp2=[4], hybridization_sp3=[4], hybridization_sp3d=[4], hybridization_sp3d2=[4], chiral_type_chi_tetrahedral_cw=[4], chiral_type_chi_tetrahedral_ccw=[4], chiral_type_chi_other=[4], chiral_type_chi_tetrahedral=[4], chiral_type_chi_allene=[4], chiral_type_chi_squareplanar=[4], chiral_type_chi_trigonalbipyramidal=[4], chiral_type_chi_octahedral=[4], degree=[4], radical_count=[4], bond_type=[4], bond_type_aromatic=[4], bond_conjugated=[4], bond_stereo_z=[4], bond_stereo_e=[4], bond_stereo_cis=[4], bond_stereo_trans=[4], num_nodes=4)

## Classification

In [4]:
# Compare to https://github.com/chemprop/chemprop
train, valid, test = molnet_loader("bbbp", reload=False)
train_loader, valid_loader, test_loader = molnet_to_pyg(
    train,
    valid,
    test,
    label_type=torch.long,
    imbalanced_sampler=True,
    secfp=False,
    index_graphs=False,
    # atom_attrs=["atomic_num", "charge", "hydrogen_count"],
    # bond_attrs=["bond_type"],
)

[21:53:07] Explicit valence for atom # 1 N, 4, is greater than permitted
Failed to featurize datapoint 59, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
[21:53:07] Explicit valence for atom # 6 N, 4, is greater than permitted
Failed to featurize datapoint 61, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)


[21:53:07] Explicit valence for atom # 6 N, 4, is greater than permitted
Failed to featurize datapoint 391, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
[21:53:07] Explicit valence for atom # 11 N, 4, is greater than permitted
Failed to featurize datapoint 614, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
did not match C++ signature:
    CanonicalRankAtoms(RDKit::ROMol mol, bool breakTies=True, bool includeChirality=True, bool includeIsotopes=True)
[21:53:07] Explicit valence for atom # 12 N, 4, is greater than permitted
Failed to featurize datapoint 642, None. Appending empty array
Exception message: Python argument types in
    rdkit.Chem.rdmolfiles.CanonicalRankAtoms(NoneType)
d

In [5]:
num_node_features = train_loader.dataset[0].num_node_features
num_edge_features = train_loader.dataset[0].num_edge_features
model = GNNSetRepClassifier(num_node_features, 256, 6, num_edge_features, 8, 16)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
criterion = torch.nn.NLLLoss()

trainer = Trainer(
    model,
    optimizer,
    criterion,
    200,
    [Accuracy(task="binary"), AUROC(task="binary")],
    [Accuracy(task="binary"), AUROC(task="binary")],
    [Accuracy(task="binary"), AUROC(task="binary")],
    scheduler=scheduler,
    # monitor_metric=1,
    # monitor_lower_is_better=False
)

trainer.train(train_loader, valid_loader)
trainer.test(test_loader)

*  Epoch 1: Train loss: 0.653 (BinaryAccuracy: 0.633, BinaryAUROC: 0.627)  Valid loss: 0.564 (BinaryAccuracy: 0.781, BinaryAUROC: 0.77)
*  Epoch 2: Train loss: 0.61 (BinaryAccuracy: 0.678, BinaryAUROC: 0.672)  Valid loss: 0.53 (BinaryAccuracy: 0.807, BinaryAUROC: 0.794)
*  Epoch 3: Train loss: 0.606 (BinaryAccuracy: 0.675, BinaryAUROC: 0.673)  Valid loss: 0.484 (BinaryAccuracy: 0.844, BinaryAUROC: 0.837)
|  Epoch 4: Train loss: 0.586 (BinaryAccuracy: 0.686, BinaryAUROC: 0.688)  Valid loss: 0.559 (BinaryAccuracy: 0.807, BinaryAUROC: 0.798)
|  Epoch 5: Train loss: 0.586 (BinaryAccuracy: 0.706, BinaryAUROC: 0.702)  Valid loss: 0.496 (BinaryAccuracy: 0.807, BinaryAUROC: 0.797)
|  Epoch 6: Train loss: 0.588 (BinaryAccuracy: 0.693, BinaryAUROC: 0.691)  Valid loss: 0.506 (BinaryAccuracy: 0.87, BinaryAUROC: 0.862)
*  Epoch 7: Train loss: 0.581 (BinaryAccuracy: 0.683, BinaryAUROC: 0.684)  Valid loss: 0.448 (BinaryAccuracy: 0.823, BinaryAUROC: 0.813)
*  Epoch 8: Train loss: 0.549 (BinaryAccuracy

[]

In [5]:
num_node_features = train_loader.dataset[0].num_node_features
num_edge_features = train_loader.dataset[0].num_edge_features
model = GNNClassifier(num_node_features, 256, 6, num_edge_features)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
criterion = torch.nn.NLLLoss()

trainer = Trainer(
    model,
    optimizer,
    criterion,
    200,
    [Accuracy(task="binary"), AUROC(task="binary")],
    [Accuracy(task="binary"), AUROC(task="binary")],
    [Accuracy(task="binary"), AUROC(task="binary")],
    scheduler=scheduler,
    # monitor_metric=1,
    # monitor_lower_is_better=False
)

trainer.train(train_loader, valid_loader)
trainer.test(test_loader)

AttributeError: 'GNNClassifier' object has no attribute 'Wc'

## Regression

In [None]:
train, valid, test = molnet_loader("lipo")
train_loader, valid_loader, test_loader = molnet_to_pyg(
    train,
    valid,
    test,
    label_type=torch.float,
    # atom_attrs=[
    #     "atomic_num",
    #     "charge",
    #     "aromatic",
    #     "is_in_ring",
    #     "hydrogen_count",
    #     "hybridization_sp",
    #     "hybridization_sp2",
    #     "hybridization_sp3",
    #     "hybridization_sp3d",
    #     "hybridization_sp3d2",
    #     "chiral_type_chi_tetrahedral_cw",
    #     "chiral_type_chi_tetrahedral_ccw",
    #     "chiral_type_chi_other",
    #     "chiral_type_chi_tetrahedral",
    #     "chiral_type_chi_allene",
    #     "chiral_type_chi_squareplanar",
    #     "chiral_type_chi_trigonalbipyramidal",
    #     "chiral_type_chi_octahedral",
    #     "degree",
    #     "radical_count"
    # ]
)

num_node_features = train_loader.dataset[0].num_node_features
num_edge_features = train_loader.dataset[0].num_edge_features
model = GNNSetRepRegressor(num_node_features, 512, 2, num_edge_features, 8, 16)
# model = GNNRegressor(num_node_features, 512, 2, num_edge_features)
# model = GNNSetRepRegressor(num_node_features, 512, 2, num_edge_features, 8, 32, gnn=GAT(num_node_features, 512, 4, jk="cat", heads=8))


optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

explainer = RegressionExplainer(model, valid_loader)

trainer = Trainer(
    model,
    optimizer,
    criterion,
    200,
    [R2Score(), MeanSquaredError(squared=False)],
    [R2Score(), MeanSquaredError(squared=False)],
    [R2Score(), MeanSquaredError(squared=False)],
    # scheduler=scheduler,
    monitor_metric=1,
    # monitor_lower_is_better=False
    # explainer=explainer
)

trainer.train(train_loader, valid_loader)
trainer.test(test_loader, average_n_epochs=0)