# Training a GCN with NAGL

## Imports

In [1]:
import os
import pickle
from pathlib import Path

import tqdm

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar

from openff.toolkit import Molecule


from openff.nagl import GNNModel
from openff.nagl.features import atoms, bonds
from openff.nagl.nn.dataset import DGLMoleculeLightningDataModule
from openff.nagl.storage.record import MoleculeRecord
from openff.nagl.storage import MoleculeStore


  from .autonotebook import tqdm as notebook_tqdm


## Configuration

In [2]:
output_directory = Path('output') # The path to an output directory

n_epochs = 200 # Number of epochs"

n_gpus = 1 # Number of gpus

partial_charge_method = "am1bcc" # Method

# dataset_paths = [Path("alkanes.sqlite")]
dataset_paths = [Path("../prepare-dataset/alkanes_by_josh.sqlite")]


## Put together a test dataset

In [3]:
test_smiles = [
    "CCCCCCC",
    "CC(C)C(C)C",
    "CC(C)(C)C",
]

records = []

for smiles in tqdm.tqdm(test_smiles, desc="Labeling molecules"):
    mol = Molecule.from_smiles(smiles, allow_undefined_stereo=True)
    record = MoleculeRecord.from_openff(
        mol,
        partial_charge_methods=["am1bcc", "am1"],
        generate_conformers=True,
        n_conformer_pool=500,
        n_conformers=10,
        rms_cutoff=0.05,
    )
    records.append(record)
    
test_set_path = Path("my_first_test_dataset.sqlite")
if test_set_path.exists():
    test_set_path.unlink()

store = MoleculeStore(test_set_path)
store.store(records)

Labeling molecules: 100%|█████████████████████████████████████████████| 3/3 [00:00<00:00,  4.36it/s]
grouping records to store by InChI key: 100%|████████████████████████| 3/3 [00:00<00:00, 569.65it/s]
storing grouped records: 100%|███████████████████████████████████████| 3/3 [00:00<00:00, 325.29it/s]


## Create the model

In [4]:
atom_features = (
    atoms.AtomicElement(["C", "O", "H", "N", "S", "F", "Br", "Cl", "I", "P"]),
    atoms.AtomConnectivity(),
    atoms.AtomAverageFormalCharge(),
    atoms.AtomHybridization(),
    atoms.AtomInRingOfSize(3),
    atoms.AtomInRingOfSize(4),
    atoms.AtomInRingOfSize(5),
    atoms.AtomInRingOfSize(6),
)

bond_features = (
    bonds.BondInRingOfSize(3),
    bonds.BondInRingOfSize(4),
    bonds.BondInRingOfSize(5),
    bonds.BondInRingOfSize(6),
)

In [5]:

model = GNNModel(
    convolution_architecture="SAGEConv",
    n_convolution_hidden_features=128,
    n_convolution_layers=3,
    n_readout_hidden_features=128,
    n_readout_layers=4,
    activation_function="ReLU",
    postprocess_layer="compute_partial_charges",
    readout_name=f"{partial_charge_method}-charges",
    learning_rate=0.001,
    atom_features=atom_features,
    bond_features=bond_features,
)

## Specify the training, validation and test data

- **training**: Data the model is trained against

- **validation**: Data used to validate the model as it is trained

- **tests**: Data used to test that the final model is good

In [21]:
!rm -r output data

data_module = DGLMoleculeLightningDataModule(
    atom_features=atom_features,
    bond_features=bond_features,
    partial_charge_method=partial_charge_method,
    training_set_paths=dataset_paths,
    validation_set_paths=dataset_paths,
    test_set_paths=[test_set_path],
)

rm: cannot remove 'output': No such file or directory


## Train the model

In [22]:
os.makedirs(str(output_directory), exist_ok=True)

logger = TensorBoardLogger(output_directory)

callbacks = [TQDMProgressBar(), ModelCheckpoint(save_top_k=1, monitor="val_loss")]

trainer = Trainer(
#     gpus=n_gpus,
    min_epochs=n_epochs,
    max_epochs=n_epochs,
    logger=logger,
    callbacks=callbacks,
)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [23]:
trainer.fit(
    model, 
    datamodule=data_module, 
)

featurizing molecules: 100%|████████████████████████████████████████| 10/10 [00:00<00:00, 60.40it/s]
featurizing molecules: 100%|██████████████████████████████████████████| 3/3 [00:00<00:00, 52.01it/s]
Missing logger folder: output/lightning_logs

  | Name               | Type              | Params
---------------------------------------------------------
0 | convolution_module | ConvolutionModule | 72.3 K
1 | readout_modules    | ModuleDict        | 66.3 K
---------------------------------------------------------
138 K     Trainable params
0         Non-trainable params
138 K     Total params
0.555     Total estimated model params size (MB)


Epoch 0:  50%|████████████████                | 1/2 [00:00<00:00, 125.44it/s, loss=0.00182, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 0: 100%|████████████████████████████████| 2/2 [00:00<00:00, 120.47it/s, loss=0.00182, v_num=0][A
Epoch 1:  50%|████████████████▌                | 1/2 [00:00<00:00, 134.98it/s, loss=0.0267, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 1: 100%|█████████████████████████████████| 2/2 [00:00<00:00, 127.08it/s, loss=0.0267, v_num=0][A
Epoch 2:  50%|████████████████▌                | 1/2 [00:00<00:00, 132.55it/s, loss=0.0275, v_num=0][A

Epoch 18:  50%|████████████████                | 1/2 [00:00<00:00, 142.55it/s, loss=0.0103, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 18: 100%|████████████████████████████████| 2/2 [00:00<00:00, 124.89it/s, loss=0.0103, v_num=0][A
Epoch 19:  50%|███████████████▌               | 1/2 [00:00<00:00, 137.99it/s, loss=0.00983, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 19: 100%|███████████████████████████████| 2/2 [00:00<00:00, 128.84it/s, loss=0.00983, v_num=0][A
Epoch 20:  50%|███████████████▌               | 1/2 [00:00<00:00, 146.52it/s, loss=0.00986, v_num=0]

Epoch 36:  50%|███████████████▌               | 1/2 [00:00<00:00, 112.49it/s, loss=0.00151, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 36: 100%|███████████████████████████████| 2/2 [00:00<00:00, 104.95it/s, loss=0.00151, v_num=0][A
Epoch 37:  50%|███████████████▌               | 1/2 [00:00<00:00, 137.46it/s, loss=0.00135, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 37: 100%|███████████████████████████████| 2/2 [00:00<00:00, 123.95it/s, loss=0.00135, v_num=0][A
Epoch 38:  50%|████████████████                | 1/2 [00:00<00:00, 136.31it/s, loss=0.0012, v_num=0]

Epoch 54:  50%|███████████████▌               | 1/2 [00:00<00:00, 120.13it/s, loss=0.00103, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 54: 100%|███████████████████████████████| 2/2 [00:00<00:00, 112.38it/s, loss=0.00103, v_num=0][A
Epoch 55:  50%|███████████████▌               | 1/2 [00:00<00:00, 112.78it/s, loss=0.00101, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 55: 100%|███████████████████████████████| 2/2 [00:00<00:00, 108.26it/s, loss=0.00101, v_num=0][A
Epoch 56:  50%|███████████████               | 1/2 [00:00<00:00, 111.75it/s, loss=0.000973, v_num=0]

Epoch 72:  50%|███████████████▌               | 1/2 [00:00<00:00, 142.85it/s, loss=0.00155, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 72: 100%|███████████████████████████████| 2/2 [00:00<00:00, 130.77it/s, loss=0.00155, v_num=0][A
Epoch 73:  50%|████████████████                | 1/2 [00:00<00:00, 143.38it/s, loss=0.0016, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 73: 100%|████████████████████████████████| 2/2 [00:00<00:00, 131.30it/s, loss=0.0016, v_num=0][A
Epoch 74:  50%|███████████████▌               | 1/2 [00:00<00:00, 116.07it/s, loss=0.00165, v_num=0]

Epoch 90:  50%|███████████████▌               | 1/2 [00:00<00:00, 142.41it/s, loss=0.00277, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 90: 100%|███████████████████████████████| 2/2 [00:00<00:00, 127.87it/s, loss=0.00277, v_num=0][A
Epoch 91:  50%|███████████████▌               | 1/2 [00:00<00:00, 113.53it/s, loss=0.00278, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 91: 100%|███████████████████████████████| 2/2 [00:00<00:00, 103.31it/s, loss=0.00278, v_num=0][A
Epoch 92:  50%|███████████████▌               | 1/2 [00:00<00:00, 110.14it/s, loss=0.00273, v_num=0]

Epoch 108:  50%|██████████████▌              | 1/2 [00:00<00:00, 119.42it/s, loss=0.000912, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 108: 100%|█████████████████████████████| 2/2 [00:00<00:00, 112.99it/s, loss=0.000912, v_num=0][A
Epoch 109:  50%|███████████████               | 1/2 [00:00<00:00, 117.47it/s, loss=0.00091, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 109: 100%|██████████████████████████████| 2/2 [00:00<00:00, 110.99it/s, loss=0.00091, v_num=0][A
Epoch 110:  50%|██████████████▌              | 1/2 [00:00<00:00, 120.91it/s, loss=0.000868, v_num=0]

Epoch 126:  50%|███████████████               | 1/2 [00:00<00:00, 118.67it/s, loss=0.00194, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 126: 100%|██████████████████████████████| 2/2 [00:00<00:00, 112.31it/s, loss=0.00194, v_num=0][A
Epoch 127:  50%|███████████████               | 1/2 [00:00<00:00, 116.02it/s, loss=0.00216, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 127: 100%|██████████████████████████████| 2/2 [00:00<00:00, 109.95it/s, loss=0.00216, v_num=0][A
Epoch 128:  50%|███████████████               | 1/2 [00:00<00:00, 117.89it/s, loss=0.00231, v_num=0]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 180:  50%|███████████████               | 1/2 [00:00<00:00, 147.08it/s, loss=0.00165, v_num=0]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 180: 100%|██████████████████████████████| 2/2 [00:00<00:00, 129.80it/s, loss=0.00165, v_num=0][A
Epoch 181:  50%|███████████████               | 1/2 [00:00<00:00, 131.92it/s, loss=0.00156, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 181: 100%|██████████████████████████████| 2/2 [00:00<00:00, 124.12it/s, loss=0.00156, v_num=0][A
Epoch 182:  50%|███████████████               | 1/2 [00:00<00:00, 141.76it/s, loss=0.00158, v_num=0][A

Epoch 198:  50%|███████████████               | 1/2 [00:00<00:00, 146.03it/s, loss=0.00163, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 198: 100%|██████████████████████████████| 2/2 [00:00<00:00, 133.27it/s, loss=0.00163, v_num=0][A
Epoch 199:  50%|███████████████               | 1/2 [00:00<00:00, 147.50it/s, loss=0.00155, v_num=0][A
Validation: 0it [00:00, ?it/s][A
Validation:   0%|                                                             | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                | 0/1 [00:00<?, ?it/s][A
Epoch 199: 100%|██████████████████████████████| 2/2 [00:00<00:00, 130.54it/s, loss=0.00155, v_num=0][A
Epoch 199: 100%|██████████████████████████████| 2/2 [00:00<00:00, 119.03it/s, loss=0.00155, v_num=0]

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|██████████████████████████████| 2/2 [00:00<00:00, 113.16it/s, loss=0.00155, v_num=0]


In [24]:
trainer.test(model, data_module)

Testing DataLoader 0: 100%|██████████████████████████████████████████| 1/1 [00:00<00:00, 215.41it/s]


[{'test_loss': 0.0028741960413753986}]

## Results!

In [10]:
print("--- Best model ---")
print(callbacks[1].best_model_path)
print(callbacks[1].best_model_score)
metrics_file = Path(output_directory) / "metrics.pkl"
with open(str(metrics_file), "wb") as f:
    metrics = (trainer.callback_metrics, trainer.logged_metrics)
    pickle.dump(metrics, f)

print(f"Wrote metrics to {str(metrics_file)}")

--- Best model ---
output/lightning_logs/version_0/checkpoints/epoch=155-step=156.ckpt
tensor(0.0005)
Wrote metrics to output/metrics.pkl


In [11]:
propane = Molecule.from_smiles("CCC")

model.compute_property(propane)

tensor([[-0.0893],
        [-0.0795],
        [-0.0893],
        [ 0.0311],
        [ 0.0311],
        [ 0.0311],
        [ 0.0357],
        [ 0.0357],
        [ 0.0311],
        [ 0.0311],
        [ 0.0311]], grad_fn=<CatBackward0>)

In [12]:
propane.assign_partial_charges(partial_charge_method)
propane.partial_charges

0,1
Magnitude,[-0.09227999977090141 -0.08139999888160011 -0.09227999977090141  0.03215999969027259 0.03215999969027259 0.03215999969027259  0.036500000140883705 0.036500000140883705 0.03215999969027259  0.03215999969027259 0.03215999969027259]
Units,elementary_charge


In [14]:
mol = Molecule.from_smiles(test_smiles[0])

model.compute_property(mol)

tensor([[-0.0886],
        [-0.0774],
        [-0.0757],
        [-0.0756],
        [-0.0757],
        [-0.0774],
        [-0.0886],
        [ 0.0317],
        [ 0.0317],
        [ 0.0317],
        [ 0.0367],
        [ 0.0367],
        [ 0.0371],
        [ 0.0371],
        [ 0.0371],
        [ 0.0371],
        [ 0.0371],
        [ 0.0371],
        [ 0.0367],
        [ 0.0367],
        [ 0.0317],
        [ 0.0317],
        [ 0.0317]], grad_fn=<CatBackward0>)

In [16]:
mol.assign_partial_charges(partial_charge_method)
mol.partial_charges

0,1
Magnitude,[-0.09234000029771225 -0.07955999704806702 -0.07857999983041183  -0.07912000301091568 -0.07857999983041183 -0.07955999704806702  -0.09234000029771225 0.03220000085623368 0.03220000085623368  0.03220000085623368 0.03793999892861947 0.03793999892861947  0.039289999429298485 0.039289999429298485 0.03897999939711198  0.03897999939711198 0.039289999429298485 0.039289999429298485  0.03793999892861947 0.03793999892861947 0.03220000085623368  0.03220000085623368 0.03220000085623368]
Units,elementary_charge


In [19]:
list(model.compute_property(mol)), list(mol.partial_charges)

([tensor([-0.0886], grad_fn=<UnbindBackward0>),
  tensor([-0.0774], grad_fn=<UnbindBackward0>),
  tensor([-0.0757], grad_fn=<UnbindBackward0>),
  tensor([-0.0756], grad_fn=<UnbindBackward0>),
  tensor([-0.0757], grad_fn=<UnbindBackward0>),
  tensor([-0.0774], grad_fn=<UnbindBackward0>),
  tensor([-0.0886], grad_fn=<UnbindBackward0>),
  tensor([0.0317], grad_fn=<UnbindBackward0>),
  tensor([0.0317], grad_fn=<UnbindBackward0>),
  tensor([0.0317], grad_fn=<UnbindBackward0>),
  tensor([0.0367], grad_fn=<UnbindBackward0>),
  tensor([0.0367], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0371], grad_fn=<UnbindBackward0>),
  tensor([0.0367], grad_fn=<UnbindBackward0>),
  tensor([0.0367], grad_fn=<UnbindBackward0>),
  tensor([0.0317], grad_fn=<UnbindBackward0>),
  tens