<a href="https://colab.research.google.com/github/ntua-unit-of-control-and-informatics/jaqpot-google-collab-examples/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### This example demonstrates binary graph classification on the AMES mutagenicity dataset.

In [1]:
!pip install PyTDC --quiet
!pip install jaqpotpy --quiet

Import the required libraries

In [2]:
import warnings
import torch
from torch_geometric.loader import DataLoader
from jaqpotpy import Jaqpot
from jaqpotpy.descriptors.graph import SmilesGraphFeaturizer
from jaqpotpy.datasets import SmilesGraphDataset
from jaqpotpy.models.torch_geometric_models.graph_neural_network import GraphSageNetwork, pyg_to_onnx
from jaqpotpy.models.trainers.graph_trainers import BinaryGraphModelTrainer

# Silence RDKit log messages and Python warnings to keep the output readable
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
warnings.simplefilter("ignore")

Load the classification data. We use the TDC library to obtain train, validation, and test splits of the AMES mutagenicity dataset, with SMILES strings (`Drug`) as inputs and binary labels (`Y`) as endpoints.

In [4]:
from tdc.single_pred import Tox

# Load the AMES mutagenicity dataset and split it into train/valid/test sets
data = Tox(name='AMES')
data_splits = data.get_split()

def split_to_list(split):
    """Return the SMILES strings and binary labels of one split as Python lists."""
    return data_splits[split]['Drug'].to_list(), data_splits[split]['Y'].to_list()

# Lists of SMILES and endpoints
train_smiles, train_y = split_to_list('train')
val_smiles, val_y = split_to_list('valid')
test_smiles, test_y = split_to_list('test')

Found local copy...
Loading...
Done!
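
AMES is a binary task, so it is worth checking how balanced the labels are before reading accuracy-style metrics. The sketch below is a plain-Python helper; the name `class_balance` is hypothetical and not part of TDC or jaqpotpy. In the notebook you would call it on `train_y`, `val_y`, or `test_y`.

```python
from collections import Counter

def class_balance(labels):
    """Return the fraction of each class in a list of 0/1 labels (hypothetical helper)."""
    counts = Counter(int(y) for y in labels)
    total = sum(counts.values())
    return {label: count / total for label, count in sorted(counts.items())}

# Toy labels for illustration; in the notebook, pass train_y, val_y or test_y instead.
example_y = [1, 0, 1, 1, 0, 1, 0, 1]
print(class_balance(example_y))  # {0: 0.375, 1: 0.625}
```

A strongly skewed split would make `balanced_accuracy` and `mcc` more informative than plain `accuracy`.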


First, we create a SmilesGraphFeaturizer instance with edge features disabled and add four atom features, together with their allowable values where needed.

In [5]:
featurizer = SmilesGraphFeaturizer(include_edge_features=False)
featurizer.add_atom_feature(
    "symbol",
    ['C', 'K', 'S', 'Br', 'H', 'F', 'N', 'Cl', 'P', 'O', 'I', 'UNK']
)
featurizer.add_atom_feature("total_num_hs", [0, 1, 2, 3, 4])
featurizer.add_atom_feature("degree", [0, 1, 2, 3, 4, 5, 6])
featurizer.add_atom_feature("is_aromatic")
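
Conceptually, each listed atom feature becomes a one-hot block over its allowable values, and a boolean feature such as `is_aromatic` contributes a single 0/1 entry. The sketch below is an assumption about how such a featurizer behaves, not jaqpotpy's actual code; in particular, mapping unlisted symbols to the `'UNK'` slot is assumed from the presence of `'UNK'` in the symbol list, and the helper names are hypothetical. Under these assumptions the four features give 12 + 5 + 7 + 1 = 25 node features; `featurizer.get_num_node_features()` is the authoritative count.

```python
def one_hot(value, allowable, unknown="UNK"):
    """One-hot encode `value` over `allowable`; unlisted values fall back to `unknown` if it is listed."""
    if value not in allowable and unknown in allowable:
        value = unknown
    return [1 if value == option else 0 for option in allowable]

SYMBOLS = ['C', 'K', 'S', 'Br', 'H', 'F', 'N', 'Cl', 'P', 'O', 'I', 'UNK']
NUM_HS = [0, 1, 2, 3, 4]
DEGREES = [0, 1, 2, 3, 4, 5, 6]

def atom_vector(symbol, total_num_hs, degree, is_aromatic):
    """Concatenate the one-hot blocks and the aromaticity flag for one atom (hypothetical helper)."""
    return (one_hot(symbol, SYMBOLS)
            + one_hot(total_num_hs, NUM_HS)
            + one_hot(degree, DEGREES)
            + [int(is_aromatic)])

# An aromatic carbon with one hydrogen and degree 2, e.g. a CH in benzene.
vec = atom_vector('C', 1, 2, True)
print(len(vec))  # 25
```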

Create the train, validation, and test datasets, then precompute the graph featurization of every molecule

In [6]:
train_dataset = SmilesGraphDataset(
    smiles=train_smiles, y=train_y, featurizer=featurizer
)
val_dataset = SmilesGraphDataset(
    smiles=val_smiles, y=val_y, featurizer=featurizer
)
test_dataset = SmilesGraphDataset(
    smiles=test_smiles, y=test_y, featurizer=featurizer
)

train_dataset.precompute_featurization()
val_dataset.precompute_featurization()
test_dataset.precompute_featurization()

Create a model. For this demonstration, a GraphSageNetwork is used with node features only (edge features are disabled in the featurizer).

In [7]:
node_features = featurizer.get_num_node_features()
model = GraphSageNetwork(
    input_dim=node_features,
    hidden_layers=1,
    hidden_dim=32,
    output_dim=1,
    activation=torch.nn.ReLU(),
    dropout_proba=0.2,
    batch_norm=False,
    seed=42,
    pooling="add",
)

ReLU()
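
GraphSAGE updates each node by combining its own features with an aggregate of its neighbours' features, and `pooling="add"` then sums the node embeddings into one vector per molecule. The pure-Python sketch below shows one mean-aggregation layer followed by sum pooling on a toy graph. It assumes jaqpotpy's `GraphSageNetwork` follows the standard formulation (as in PyTorch Geometric's `SAGEConv` with mean aggregation); the weights are fixed toy values rather than learned ones, there is no bias term, and the embeddings are 2-dimensional instead of `hidden_dim=32`.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(v):
    return [max(0.0, a) for a in v]

def sage_layer(features, neighbors, W_self, W_neigh):
    """One GraphSAGE layer with mean aggregation: h_i' = ReLU(W_self h_i + W_neigh mean_j h_j)."""
    out = []
    for i, h in enumerate(features):
        nbrs = neighbors[i]
        if nbrs:
            mean = [sum(features[j][k] for j in nbrs) / len(nbrs) for k in range(len(h))]
        else:
            mean = [0.0] * len(h)  # isolated node: no neighbour contribution
        self_part = matvec(W_self, h)
        neigh_part = matvec(W_neigh, mean)
        out.append(relu([a + b for a, b in zip(self_part, neigh_part)]))
    return out

def add_pool(node_embeddings):
    """Sum node embeddings into one graph-level vector, as with pooling="add"."""
    return [sum(col) for col in zip(*node_embeddings)]

# Toy 3-node path graph 0-1-2 with 2-dimensional node features.
features = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = {0: [1], 1: [0, 2], 2: [1]}
W_self = [[1.0, 0.0], [0.0, 1.0]]   # identity on the node's own features
W_neigh = [[0.5, 0.0], [0.0, 0.5]]  # half weight on the neighbour mean
hidden = sage_layer(features, neighbors, W_self, W_neigh)
graph_vector = add_pool(hidden)
print(hidden)        # [[1.0, 0.5], [0.5, 1.25], [1.0, 1.5]]
print(graph_vector)  # [2.5, 3.25]
```

Sum pooling keeps information about molecule size, whereas mean pooling would normalize it away; which works better is an empirical question for the dataset.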


Define the optimizer and the loss function

In [8]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loss = torch.nn.BCEWithLogitsLoss()
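
`BCEWithLogitsLoss` combines a sigmoid with binary cross-entropy, which is why the model emits a single raw logit (`output_dim=1`). Per example it computes -[y log σ(x) + (1 - y) log(1 - σ(x))], evaluated in a numerically stable form. The sketch below reproduces the default mean-reduced loss (no `pos_weight`) in plain Python for illustration; it is not PyTorch's implementation.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, using the stable form
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    losses = [max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
              for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)

def naive_bce(logits, targets):
    """Direct formula through the sigmoid; breaks down for large |x|."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)

logits = [2.0, -1.0, 0.5]
targets = [1.0, 0.0, 1.0]
print(bce_with_logits(logits, targets))
print(naive_bce(logits, targets))        # agrees for moderate logits
print(bce_with_logits([1000.0], [0.0]))  # 1000.0; naive_bce raises ValueError here (log of 0)
```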

Create an instance of the binary classification trainer

In [9]:
trainer = BinaryGraphModelTrainer(
    model=model, # Jaqpotpy Graph Model
    n_epochs=20,
    optimizer=optimizer,
    loss_fn=loss,
    scheduler=None
)
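
During evaluation, a binary trainer has to turn each logit into a probability and then into a class label. The sketch below assumes the common convention of a sigmoid followed by a 0.5 threshold (equivalently, logit ≥ 0); the threshold `BinaryGraphModelTrainer` actually uses is not shown in this notebook, so treat it as an assumption. `predict_labels` is a hypothetical helper.

```python
import math

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def predict_labels(logits, threshold=0.5):
    """Convert raw logits to probabilities and 0/1 labels (the threshold is an assumed default)."""
    probs = [sigmoid(x) for x in logits]
    labels = [1 if p >= threshold else 0 for p in probs]
    return probs, labels

probs, labels = predict_labels([-2.0, 0.0, 3.0])
print(labels)  # [0, 1, 1]
```

Threshold-based metrics such as precision and recall change if the threshold moves, while `roc_auc` does not, because it depends only on the ranking of the probabilities.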

Create PyTorch Geometric data loaders

In [10]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
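
PyTorch Geometric's `DataLoader` batches graphs by merging them into one large disconnected graph: node feature rows are concatenated, each graph's edge indices are shifted by the number of nodes that precede it, and a `batch` vector records which graph every node belongs to. That vector is what lets `pooling="add"` sum node embeddings per molecule. The plain-Python sketch below mimics this collation on toy graphs (PyG stores edges as a 2 × E tensor rather than a list of pairs); it illustrates the idea and is not PyG's implementation.

```python
def collate_graphs(graphs):
    """Merge graphs given as (node_features, edge_list) into one disjoint graph.

    Returns concatenated node features, offset edges, and a batch vector
    mapping each node to the index of its source graph."""
    x, edges, batch = [], [], []
    offset = 0
    for graph_id, (node_feats, edge_list) in enumerate(graphs):
        x.extend(node_feats)
        edges.extend((src + offset, dst + offset) for src, dst in edge_list)
        batch.extend([graph_id] * len(node_feats))
        offset += len(node_feats)
    return x, edges, batch

def add_pool_by_batch(node_values, batch, num_graphs):
    """Sum per-node scalars into per-graph totals using the batch vector."""
    totals = [0.0] * num_graphs
    for value, graph_id in zip(node_values, batch):
        totals[graph_id] += value
    return totals

# Two toy molecules: a 2-node graph and a 3-node chain (edges stored in both directions).
g1 = ([[1.0], [2.0]], [(0, 1), (1, 0)])
g2 = ([[3.0], [4.0], [5.0]], [(0, 1), (1, 0), (1, 2), (2, 1)])
x, edges, batch = collate_graphs([g1, g2])
print(edges)  # [(0, 1), (1, 0), (2, 3), (3, 2), (3, 4), (4, 3)]
print(batch)  # [0, 0, 1, 1, 1]
print(add_pool_by_batch([row[0] for row in x], batch, 2))  # [3.0, 12.0]
```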

Train for 20 epochs, evaluating on the validation set after each epoch

In [11]:
trainer.train(train_loader, val_loader)

Epoch 1/20: 100%|██████████| 40/40 [00:00<00:00, 60.59it/s, loss=0.749]
 Train: loss=0.7486 | accuracy=0.6199 | balanced_accuracy=0.6116 | precision=0.6325 | recall=0.7119 | f1=0.6698 | mcc=0.2280 | roc_auc=0.6750
 Val:   loss=0.6378 | accuracy=0.6181 | balanced_accuracy=0.6025 | precision=0.6534 | recall=0.7098 | f1=0.6805 | mcc=0.2092 | roc_auc=0.6855
Epoch 2/20: 100%|██████████| 40/40 [00:00<00:00, 60.47it/s, loss=0.658]
 Train: loss=0.6577 | accuracy=0.6378 | balanced_accuracy=0.6277 | precision=0.6419 | recall=0.7492 | f1=0.6914 | mcc=0.2639 | roc_auc=0.7118
 Val:   loss=0.6215 | accuracy=0.6511 | balanced_accuracy=0.6321 | precision=0.6723 | recall=0.7626 | f1=0.7146 | mcc=0.2740 | roc_auc=0.7119
Epoch 3/20: 100%|██████████| 40/40 [00:00<00:00, 65.48it/s, loss=0.638]
 Train: loss=0.6376 | accuracy=0.6592 | balanced_accuracy=0.6508 | precision=0.6637 | recall=0.7517 | f1=0.7050 | mcc=0.3086 | roc_auc=0.7322
 Val:   loss=0.6112 | accuracy=0.6703 | balanced_accuracy=0.6562 | precision=0.6962 | recall=0.7530 | f1=0.7235 | mcc=0.3184 | roc_auc=0.7293
Epoch 4/20: 100%|██████████| 40/40 [00:00<00:00, 62.44it/s, loss=0.62]
 Train: loss=0.6200 | accuracy=0.6741 | balanced_accuracy=0.6670 | precision=0.6799 | recall=0.7528 | f1=0.7145 | mcc=0.3396 | roc_auc=0.7511
 Val:   loss=0.5989 | accuracy=0.6827 | balanced_accuracy=0.6699 | precision=0.7085 | recall=0.7578 | f1=0.7323 | mcc=0.3450 | roc_auc=0.7469
Epoch 5/20: 100%|██████████| 40/40 [00:00<00:00, 51.88it/s, loss=0.607]
 Train: loss=0.6071 | accuracy=0.6930 | balanced_accuracy=0.6906 | precision=0.7156 | recall=0.7187 | f1=0.7172 | mcc=0.3814 | roc_auc=0.7666
 Val:   loss=0.5932 | accuracy=0.7005 | balanced_accuracy=0.6986 | precision=0.7519 | recall=0.7122 | f1=0.7315 | mcc=0.3943 | roc_auc=0.7583
Epoch 6/20: 100%|██████████| 40/40 [00:00<00:00, 52.82it/s, loss=0.598]
 Train: loss=0.5978 | accuracy=0.7118 | balanced_accuracy=0.7088 | precision=0.7290 | recall=0.7448 | f1=0.7368 | mcc=0.4186 | roc_auc=0.7799
 Val:   loss=0.5800 | accuracy=0.7088 | balanced_accuracy=0.7041 | precision=0.7506 | recall=0.7362 | f1=0.7433 | mcc=0.4070 | roc_auc=0.7712
Epoch 7/20: 100%|██████████| 40/40 [00:00<00:00, 45.48it/s, loss=0.591]
 Train: loss=0.5913 | accuracy=0.7053 | balanced_accuracy=0.7153 | precision=0.8105 | recall=0.5951 | f1=0.6863 | mcc=0.4385 | roc_auc=0.7899
 Val:   loss=0.6011 | accuracy=0.6799 | balanced_accuracy=0.6977 | precision=0.8108 | recall=0.5755 | f1=0.6732 | mcc=0.3983 | roc_auc=0.7775
Epoch 8/20: 100%|██████████| 40/40 [00:00<00:00, 62.97it/s, loss=0.582]
 Train: loss=0.5815 | accuracy=0.7252 | balanced_accuracy=0.7253 | precision=0.7577 | recall=0.7242 | f1=0.7405 | mcc=0.4492 | roc_auc=0.7960
 Val:   loss=0.5680 | accuracy=0.7253 | balanced_accuracy=0.7246 | precision=0.7775 | recall=0.7290 | f1=0.7525 | mcc=0.4457 | roc_auc=0.7876
Epoch 9/20: 100%|██████████| 40/40 [00:00<00:00, 62.94it/s, loss=0.574]
 Train: loss=0.5739 | accuracy=0.6888 | balanced_accuracy=0.7044 | precision=0.8486 | recall=0.5179 | f1=0.6433 | mcc=0.4329 | roc_auc=0.7986
 Val:   loss=0.6110 | accuracy=0.6662 | balanced_accuracy=0.6898 | precision=0.8271 | recall=0.5276 | f1=0.6442 | mcc=0.3900 | roc_auc=0.7854
Epoch 10/20: 100%|██████████| 40/40 [00:00<00:00, 62.38it/s, loss=0.572]
 Train: loss=0.5716 | accuracy=0.7352 | balanced_accuracy=0.7322 | precision=0.7493 | recall=0.7680 | f1=0.7585 | mcc=0.4656 | roc_auc=0.8040
 Val:   loss=0.5533 | accuracy=0.7294 | balanced_accuracy=0.7241 | precision=0.7657 | recall=0.7602 | f1=0.7629 | mcc=0.4478 | roc_auc=0.7974
Epoch 11/20: 100%|██████████| 40/40 [00:00<00:00, 48.12it/s, loss=0.565]
 Train: loss=0.5652 | accuracy=0.7326 | balanced_accuracy=0.7365 | precision=0.7895 | recall=0.6905 | f1=0.7367 | mcc=0.4719 | roc_auc=0.8097
 Val:   loss=0.5578 | accuracy=0.7431 | balanced_accuracy=0.7484 | precision=0.8159 | recall=0.7122 | f1=0.7606 | mcc=0.4915 | roc_auc=0.8013
Epoch 12/20: 100%|██████████| 40/40 [00:00<00:00, 47.35it/s, loss=0.561]
 Train: loss=0.5609 | accuracy=0.7397 | balanced_accuracy=0.7406 | precision=0.7761 | recall=0.7300 | f1=0.7523 | mcc=0.4796 | roc_auc=0.8118
 Val:   loss=0.5481 | accuracy=0.7335 | balanced_accuracy=0.7343 | precision=0.7896 | recall=0.7290 | f1=0.7581 | mcc=0.4643 | roc_auc=0.8033
Epoch 13/20: 100%|██████████| 40/40 [00:00<00:00, 50.26it/s, loss=0.557]
 Train: loss=0.5566 | accuracy=0.7242 | balanced_accuracy=0.7340 | precision=0.8306 | recall=0.6165 | f1=0.7077 | mcc=0.4755 | roc_auc=0.8163
 Val:   loss=0.5708 | accuracy=0.7129 | balanced_accuracy=0.7277 | precision=0.8312 | recall=0.6259 | f1=0.7141 | mcc=0.4549 | roc_auc=0.8039
Epoch 14/20: 100%|██████████| 40/40 [00:00<00:00, 70.34it/s, loss=0.55]
 Train: loss=0.5502 | accuracy=0.7424 | balanced_accuracy=0.7469 | precision=0.8041 | recall=0.6934 | f1=0.7446 | mcc=0.4932 | roc_auc=0.8194
 Val:   loss=0.5501 | accuracy=0.7376 | balanced_accuracy=0.7436 | precision=0.8139 | recall=0.7026 | f1=0.7542 | mcc=0.4820 | roc_auc=0.8093
Epoch 15/20: 100%|██████████| 40/40 [00:00<00:00, 68.20it/s, loss=0.555]
 Train: loss=0.5546 | accuracy=0.7338 | balanced_accuracy=0.7415 | precision=0.8222 | recall=0.6488 | f1=0.7253 | mcc=0.4865 | roc_auc=0.8209
 Val:   loss=0.5580 | accuracy=0.7294 | balanced_accuracy=0.7409 | precision=0.8313 | recall=0.6619 | f1=0.7370 | mcc=0.4785 | roc_auc=0.8102
Epoch 16/20: 100%|██████████| 40/40 [00:01<00:00, 27.49it/s, loss=0.548]
 Train: loss=0.5481 | accuracy=0.7515 | balanced_accuracy=0.7501 | precision=0.7725 | recall=0.7669 | f1=0.7697 | mcc=0.4998 | roc_auc=0.8246
 Val:   loss=0.5348 | accuracy=0.7527 | balanced_accuracy=0.7494 | precision=0.7912 | recall=0.7722 | f1=0.7816 | mcc=0.4970 | roc_auc=0.8142
Epoch 17/20: 100%|██████████| 40/40 [00:01<00:00, 22.67it/s, loss=0.552]
 Train: loss=0.5523 | accuracy=0.7556 | balanced_accuracy=0.7549 | precision=0.7804 | recall=0.7637 | f1=0.7719 | mcc=0.5089 | roc_auc=0.8266
 Val:   loss=0.5336 | accuracy=0.7500 | balanced_accuracy=0.7479 | precision=0.7930 | recall=0.7626 | f1=0.7775 | mcc=0.4930 | roc_auc=0.8161
Epoch 18/20: 100%|██████████| 40/40 [00:01<00:00, 22.31it/s, loss=0.543]
 Train: loss=0.5429 | accuracy=0.7440 | balanced_accuracy=0.7338 | precision=0.7224 | recall=0.8565 | f1=0.7837 | mcc=0.4860 | roc_auc=0.8219
 Val:   loss=0.5364 | accuracy=0.7637 | balanced_accuracy=0.7455 | precision=0.7547 | recall=0.8705 | f1=0.8085 | mcc=0.5131 | roc_auc=0.8145
Epoch 19/20: 100%|██████████| 40/40 [00:01<00:00, 31.11it/s, loss=0.546]
 Train: loss=0.5464 | accuracy=0.7454 | balanced_accuracy=0.7505 | precision=0.8124 | recall=0.6890 | f1=0.7456 | mcc=0.5009 | roc_auc=0.8283
 Val:   loss=0.5465 | accuracy=0.7418 | balanced_accuracy=0.7492 | precision=0.8244 | recall=0.6978 | f1=0.7558 | mcc=0.4934 | roc_auc=0.8136
Epoch 20/20: 100%|██████████| 40/40 [00:01<00:00, 32.32it/s, loss=0.537]
 Train: loss=0.5372 | accuracy=0.7501 | balanced_accuracy=0.7546 | precision=0.8122 | recall=0.7006 | f1=0.7523 | mcc=0.5085 | roc_auc=0.8310
 Val:   loss=0.5383 | accuracy=0.7486 | balanced_accuracy=0.7556 | precision=0.8287 | recall=0.7074 | f1=0.7633 | mcc=0.5060 | roc_auc=0.8188


Evaluate the trained model on the test set and show the metrics

In [12]:
# Use a new name so the `loss` function object defined earlier is not shadowed
test_loss, metrics, conf_matrix = trainer.evaluate(test_loader)

In [13]:
metrics

{'accuracy': 0.7376373626373627,
 'balanced_accuracy': 0.7437210046392577,
 'precision': 0.8104477611940298,
 'recall': 0.6804511278195489,
 'f1': 0.7397820163487738,
 'mcc': 0.4867306256738066,
 'roc_auc': 0.8269381660839028,
 'loss': 0.5275376163996183}
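
Unlike the threshold-based entries in this dictionary, `roc_auc` depends only on how the model ranks molecules: it equals the probability that a randomly chosen positive receives a higher score than a randomly chosen negative, with ties counted as one half. The brute-force sketch below computes it that way on toy scores; it is O(P·N) and meant for illustration, while libraries such as scikit-learn compute the same quantity more efficiently. `roc_auc_pairwise` is a hypothetical helper name.

```python
def roc_auc_pairwise(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))

labels = [1, 1, 0, 0, 1]
scores = [0.9, 0.4, 0.35, 0.8, 0.6]
print(roc_auc_pairwise(labels, scores))  # 4 of 6 pairs ranked correctly
```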

In [14]:
conf_matrix

array([[531, 127],
       [255, 543]])
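
The threshold-based test metrics can be recovered from this confusion matrix. The reported values are consistent with scikit-learn's layout, where rows are true classes and columns are predicted classes, i.e. `[[TN, FP], [FN, TP]]`; that layout is inferred from the numbers rather than stated by jaqpotpy. The sketch below recomputes the metrics with the standard formulas; `roc_auc` and `loss` need the raw scores and cannot be derived from counts. `metrics_from_confusion` is a hypothetical helper.

```python
import math

def metrics_from_confusion(cm):
    """Binary metrics from a 2x2 confusion matrix laid out as [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)          # sensitivity, true positive rate
    specificity = tn / (tn + fp)     # true negative rate
    mcc_denominator = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": (tp + tn) / total,
        "balanced_accuracy": (recall + specificity) / 2,
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
        "mcc": (tp * tn - fp * fn) / mcc_denominator,
    }

# Confusion matrix reported for the test set in this notebook.
test_metrics = metrics_from_confusion([[531, 127], [255, 543]])
for name, value in test_metrics.items():
    print(f"{name}: {value:.4f}")
```

Precision (about 0.81) is noticeably higher than recall (about 0.68): at the current threshold the model misses more mutagenic molecules (255 false negatives) than it falsely flags (127 false positives).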

Convert the PyTorch Geometric model to ONNX format

In [15]:
onnx_model = pyg_to_onnx(model, featurizer)

Log in to Jaqpot

In [16]:
jaqpot = Jaqpot()
jaqpot.login()

Open this URL in your browser and log in:
https://login.jaqpot.org/realms/jaqpot/protocol/openid-connect/auth?client_id=jaqpot-client&response_type=code&redirect_uri=urn:ietf:wg:oauth:2.0:oob&scope=openid email profile&state=random_state_value
Enter the authorization code you received: ad4b99b3-d552-4a19-a2ec-5f6ed2c4fc0c.2b052e5d-ef71-411f-bdd8-fb42ad1cbc5e.40e0db1a-58ce-461a-8fbb-6a4451d8587a


Deploy the model on the web

In [None]:
jaqpot.deploy_torch_model(
    onnx_model,
    featurizer=featurizer,  # Featurizer used for the model
    name="Graph Neural Network",
    description="Graph Sage Network for AMES mutagenicity classification",
    target_name="MUTAGENICITY",
    visibility="PRIVATE",
    task="binary_classification",  # Specify the task (regression or binary_classification)
)

[1m [32m 2024-11-12 08:40:22,427 - INFO - Model has been successfully uploaded. The url of the model is https://app.jaqpot.org/dashboard/models/1911[0m
