# Training a Graph Neural Network with NAGL on multiple objectives

This notebook will go through the process of training a new Graph Neural Network (GNN) on a small dataset of alkanes with multiple objectives. Please see the `train-gnn-notebook` tutorial for more on what's happening under the hood.


## Imports

In [1]:
from pathlib import Path

import numpy as np

from openff.toolkit import Molecule
from openff.units import unit



## Create the model

First, let's specify the model features.

In [2]:
from openff.nagl.features import atoms

atom_features = (
    atoms.AtomicElement(categories=["C", "H"]), # Is the atom Carbon or Hydrogen?
    atoms.AtomConnectivity(), # Is the atom bonded to 1, 2, 3, or 4 other atoms?
    atoms.AtomAverageFormalCharge(), # What is the atom's mean formal charge over the molecule's resonance forms?
    atoms.AtomHybridization(), # What is the hybridization of the atom?
    atoms.AtomInRingOfSize(ring_size=3), # Is the atom in a 3-membered ring?
    atoms.AtomInRingOfSize(ring_size=4), # Is the atom in a 4-membered ring?
    atoms.AtomInRingOfSize(ring_size=5), # Is the atom in a 5-membered ring?
    atoms.AtomInRingOfSize(ring_size=6), # Is the atom in a 6-membered ring?
)
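To give a feel for what these features become, here is a minimal sketch of one-hot encoding for the carbon atom in methane. The `one_hot` helper and the hand-written values are hypothetical and only illustrate the idea; the formal charge and hybridization features are left out, and NAGL's actual encodings may be laid out differently.

```python
import numpy as np


def one_hot(value, categories):
    """Return a vector of zeros with a single 1 at the position of `value`."""
    encoding = np.zeros(len(categories))
    encoding[categories.index(value)] = 1.0
    return encoding


# Hypothetical, hand-written description of the carbon atom in methane
element_encoding = one_hot("C", ["C", "H"])        # carbon, not hydrogen
connectivity_encoding = one_hot(4, [1, 2, 3, 4])   # bonded to 4 atoms
ring_flags = np.zeros(4)                           # not in a 3-, 4-, 5- or 6-membered ring

# The per-atom input vector is the concatenation of all feature encodings
atom_vector = np.concatenate([element_encoding, connectivity_encoding, ring_flags])
print(atom_vector)
```

Each atom in a molecule gets one such vector, and the stack of vectors is what the convolution layers operate on.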

We also need to specify the architecture of the GNN. We can make this as complicated as we like.

In [3]:
from openff.nagl.config.model import (
    ConvolutionLayer,
    ConvolutionModule,
)
from openff.nagl import GNNModel
from openff.nagl.nn.gcn import SAGEConvStack
from torch.nn import ReLU
from openff.nagl.nn.postprocess import ComputePartialCharges

single_convolution_layer = ConvolutionLayer(
    hidden_feature_size=128,  # 128 features per hidden convolution layer
    aggregator_type="mean",  # aggregate atom representations with mean
    activation_function="ReLU", # max(0, x) activation function for layer
    dropout=0.0, # no dropout
)

convolution_module = ConvolutionModule(
    architecture="SAGEConv", # GraphSAGE GCN
    layers=[single_convolution_layer] * 3, # 3 hidden convolution layers        
)
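As a rough illustration of what one GraphSAGE layer with a mean aggregator does, the sketch below updates each atom from its own features plus the mean of its neighbours' features, followed by ReLU. It is written in plain NumPy with made-up weights and no bias term; the `sage_mean_layer` function is hypothetical and is not how NAGL or DGL implement the layer.

```python
import numpy as np


def sage_mean_layer(features, adjacency, w_self, w_neigh):
    """One simplified GraphSAGE update: ReLU(h_i W_self + mean_j(h_j) W_neigh)."""
    degree = adjacency.sum(axis=1, keepdims=True)
    # Average the feature vectors of each atom's bonded neighbours
    neighbour_mean = adjacency @ features / np.maximum(degree, 1)
    return np.maximum(0.0, features @ w_self + neighbour_mean @ w_neigh)


# Methane: atom 0 is the carbon, atoms 1-4 are hydrogens bonded to it
adjacency = np.zeros((5, 5))
adjacency[0, 1:] = 1
adjacency[1:, 0] = 1

# One-hot element features: [is carbon, is hydrogen]
features = np.array([[1, 0], [0, 1], [0, 1], [0, 1], [0, 1]], dtype=float)

# Made-up weights for illustration; real layers learn these during training
w_self = np.eye(2)
w_neigh = 0.5 * np.eye(2)

hidden = sage_mean_layer(features, adjacency, w_self, w_neigh)
print(hidden)
```

Stacking three such layers, as in the configuration above, lets information travel up to three bonds away from each atom.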

We then specify the readout module.

In [4]:
from openff.nagl.config.model import (
    ForwardLayer,
    ReadoutModule,
)

single_readout_layer = ForwardLayer(
    hidden_feature_size=128,  # 128 features per hidden readout layer
    activation_function="ReLU", # max(0, x) activation function for layer
    dropout=0.0, # no dropout
)

normal_readout_module = ReadoutModule(
    pooling="atoms",
    layers=[single_readout_layer] * 4, # 4 internal readout layers
    # calculate charges with charge equilibration scheme from
    # electronegativity and hardness
    postprocess="compute_partial_charges"
)
regularised_readout_module = ReadoutModule(
    pooling="atoms",
    layers=[single_readout_layer] * 4, # 4 internal readout layers
    # calculate charges with charge equilibration scheme from
    # electronegativity, hardness, and an initial charge prediction
    postprocess="regularized_compute_partial_charges"
)
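The charge equilibration idea mentioned in the comments above can be sketched as follows. This is the standard closed-form solution for minimising a quadratic energy in the charges subject to a fixed total charge, written in NumPy with made-up electronegativity and hardness values. The function name is hypothetical, and NAGL's own post-processing layers (in particular the regularised variant) may differ in detail.

```python
import numpy as np


def equilibrate_charges(electronegativity, hardness, total_charge=0.0):
    """Minimise sum(e_i * q_i + 0.5 * s_i * q_i**2) subject to sum(q_i) == total_charge."""
    inverse_hardness = 1.0 / hardness
    ratio = electronegativity * inverse_hardness
    # Lagrange multiplier that enforces the total-charge constraint
    multiplier = (total_charge + ratio.sum()) / inverse_hardness.sum()
    return -ratio + inverse_hardness * multiplier


# Made-up per-atom values for a methane-like molecule (carbon first)
electronegativity = np.array([0.4, -0.1, -0.1, -0.1, -0.1])
hardness = np.array([2.0, 1.0, 1.0, 1.0, 1.0])

charges = equilibrate_charges(electronegativity, hardness)
print(charges, charges.sum())
```

Whatever values the network predicts for each atom, the resulting charges always sum to the requested total charge, which is why this step is useful as a final layer.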


Now we can put them together in a full `ModelConfig`, which can then be used to create a `GNNModel`. A model can have multiple readouts that derive different properties from the same convolution representation, so each readout module is specified in a dictionary with a label.

A `GNNModel` built from this config carries all of the model's hyperparameters; after we train it, the same object stores the weights as well.

In [5]:
from openff.nagl.config.model import ModelConfig

model_config = ModelConfig(
    atom_features=atom_features,
    bond_features=[],
    convolution=convolution_module,
    readouts={
        "predicted-am1bcc-charges": normal_readout_module,
        "predicted-am1-charges": regularised_readout_module
    }
)

## Put together our datasets

We need to set up three datasets here:

- **training**: Data the model is trained against

- **validation**: Data used to validate the model as it is trained

- **test**: Data used to test that the final model is good

In this example, we'll use a collection of ten molecules for training. We'll also build a test/validation dataset of three molecules that are not in the training set.

We can use the [`LabelledDataset`] class to generate training data that is saved in the `training_data` directory (or use `pyarrow` directly). First we can generate the dataset from SMILES:

In [6]:
from openff.nagl.label.dataset import LabelledDataset

training_alkanes = [
    'C',
    'CC',
    'CCC',
    'CCCC',
    'CC(C)C',
    'CCCCC',
    'CC(C)CC',
    'CCCCCC',
    'CC(C)CCC',
    'CC(CC)CC',
]

training_dataset = LabelledDataset.from_smiles(
    "training_data",
    training_alkanes,
    mapped=False,
    overwrite_existing=True,
)
training_dataset.to_pandas()

Unnamed: 0,mapped_smiles
0,[H:2][C:1]([H:3])([H:4])[H:5]
1,[H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8]
2,[H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:...
3,[H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:...
4,[H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:...
5,[H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C...
6,[H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14...
7,[H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[...
8,[H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17...
9,[H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12...


Below we specify label functions to label our molecules with the information that we will use in training and testing. Each argument is specified and annotated to explain its purpose; however, all label functions can be instantiated with default arguments (e.g. `LabelConformers()`) unless you need specific column names or arguments (e.g. changing the charge method).

**Note: the ESP label function requires `openff-recharge` to be installed.**

In [7]:
from openff.nagl.label.labels import (
    LabelConformers,
    LabelCharges,
    LabelMultipleDipoles,
    LabelMultipleESPs,
)
import openff.recharge

# generate ELF conformers
label_conformers = LabelConformers(
    # create a new 'conformers' column with the output conformers
    conformer_column="conformers",
    # create a new 'n_conformers' column with the number of conformers
    n_conformer_column="n_conformers",
    n_conformer_pool=500, # initially generate 500 conformers
    n_conformers=10, # prune to max 10 conformers
    rms_cutoff=0.05,
)

# generate AM1 charges
label_am1_charges = LabelCharges(
    charge_method="am1-mulliken", # AM1
    # use previously generated conformers instead of new ones
    use_existing_conformers=True,
    # use the 'conformers' column as input for charge assignment
    conformer_column="conformers",
    # write generated charges to 'target-am1-charges' column
    charge_column="target-am1-charges",
)

# generate AM1-BCC charges
label_am1bcc_charges = LabelCharges(
    charge_method="am1bcc", # AM1BCC
    # use previously generated conformers instead of new ones
    use_existing_conformers=True,
    # use the 'conformers' column as input for charge assignment
    conformer_column="conformers",
    # write generated charges to 'target-am1bcc-charges' column
    charge_column="target-am1bcc-charges",
)

label_am1bcc_dipoles = LabelMultipleDipoles(
    # use the 'conformers' column as input to calculate dipole moments
    conformer_column="conformers",
    # use the 'n_conformers' column as input
    n_conformer_column="n_conformers",
    # use the "target-am1bcc-charges" column as input to calculate dipole moments
    charge_column="target-am1bcc-charges",
    # write calculated dipoles to 'target-am1bcc-dipoles' column
    dipole_column="target-am1bcc-dipoles",
)

label_am1bcc_esps = LabelMultipleESPs(
    # use the 'conformers' column as input to calculate ESPs
    conformer_column="conformers",
    # use the 'n_conformers' column as input
    n_conformer_column="n_conformers",
    # use the "target-am1bcc-charges" column as input to calculate ESPs
    charge_column="target-am1bcc-charges",
    # generate new grids and inverse distances to points
    use_existing_inverse_distances=False,
    # write inverse distances from conformer to surface to this column
    inverse_distance_matrix_column="grid_inverse_distances",
    # write number of grid points for each surface to this column
    grid_length_column="esp_lengths",
    # write calculated ESPs to 'esps' column
    esp_column="esps",
)

Below we apply the label functions to actually generate the labels. The order matters, as later label functions use the output of earlier ones.

In [8]:
labellers = [
    label_conformers, # generate initial conformers
    label_am1_charges,
    label_am1bcc_charges,
    label_am1bcc_dipoles,
    label_am1bcc_esps,
]

training_dataset.apply_labellers(labellers)
training_dataset.to_pandas()

Unnamed: 0,mapped_smiles,conformers,n_conformers,target-am1-charges,target-am1bcc-charges,target-am1bcc-dipoles,esp_lengths,grid_inverse_distances,esps
0,[H:2][C:1]([H:3])([H:4])[H:5],"[-6.580352783203125e-05, -6.109476089477539e-0...",1,"[-0.2658799886703491, 0.06646999716758728, 0.0...","[-0.10868000239133835, 0.027170000597834587, 0...","[8.93940945267957e-06, 8.315918032519853e-07, ...",[386],"[0.22234336592604098, 0.23155396222058547, 0.1...","[-0.0006674971296186626, -0.000637023458232497..."
1,[H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8],"[0.8151131868362427, -0.5383363366127014, 0.49...",1,"[-0.21174000017344952, -0.21174000017344952, 0...","[-0.09384000208228827, -0.09384000208228827, 0...","[-1.368494064180048e-06, 7.544429469624747e-06...",[496],"[0.22234336592604098, 0.2127901145925834, 0.19...","[-0.0017383392833108273, -0.000694559780806171..."
2,[H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:...,"[-0.5531941056251526, -0.27858617901802063, 0....",1,"[-0.21018000082536178, -0.15999999777837234, -...","[-0.09227999977090141, -0.08139999888160011, -...","[3.7068362594441795e-05, 0.0005890870862074807...",[596],"[0.22234336592604098, 0.18767734846586412, 0.1...","[-0.0015067085988751464, -0.001471156784857612..."
3,[H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:...,"[0.4047875702381134, -0.7297815084457397, 0.70...",1,"[-0.21003000438213348, -0.15905000269412994, -...","[-0.09212999844125339, -0.08044999891093799, -...","[0.00021953966206779418, 0.0004791439598072974...",[700],"[0.22234336592604098, 0.18252362019636445, 0.2...","[-0.0012686832717835857, -0.000767904450083205..."
4,[H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:...,"[0.6623112559318542, -0.37933579087257385, -0....",1,"[-0.20747000138674462, -0.10981000374470438, -...","[-0.08957000076770782, -0.07050999999046326, -...","[0.0006696671713967151, 0.00041734828298395366...",[670],"[0.22234336592604098, 0.1385775500313658, 0.12...","[0.0003751502228112446, 3.505781229868471e-05,..."
5,[H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C...,"[-0.5110070705413818, 0.9039239883422852, -0.3...",1,"[-0.21004000306129456, -0.15812000632286072, -...","[-0.09213999658823013, -0.07952000200748444, -...","[0.0012899358122924753, -0.0007518562134603557...",[777],"[0.22234336592604098, 0.13571425014671695, 0.1...","[0.00011077201387450254, 0.0005246440916685611..."
6,[H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14...,"[2.668698787689209, -1.6484040021896362, -1.37...",1,"[-0.20766000405830495, -0.10704000250381582, -...","[-0.0897599982426447, -0.06774000100353185, -0...","[0.0006707797039303956, -0.0003757642726688048...",[762],"[0.22234336592604098, 0.1764458906335293, 0.12...","[-0.0014334151785553792, -0.000199164611945194..."
7,[H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[...,"[-0.7496212124824524, -0.7225577235221863, -0....",1,"[-0.21021999344229697, -0.15823000594973563, -...","[-0.0923200011253357, -0.0796300008893013, -0....","[0.00042592607741254174, -0.000578177091797665...",[889],"[0.22234336592604098, 0.1357225286015051, 0.11...","[0.00011405104641144023, 0.0005901320688421231..."
8,[H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17...,"[1.0540275573730469, 4.25633430480957, -2.8048...",1,"[-0.208649992197752, -0.1059999980032444, -0.2...","[-0.09075000137090683, -0.06669999659061432, -...","[0.0031352303428189288, 0.0006422425590680431,...",[869],"[0.22234336592604098, 0.1729422649620562, 0.12...","[-0.0016350722063982184, 0.0001077856251775464..."
9,[H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12...,"[0.5423022508621216, -3.9825656414031982, -0.8...",1,"[-0.2068299949169159, -0.10380999743938446, -0...","[-0.08893000297248363, -0.06451000235974788, -...","[0.006164249258937882, -0.003163944439926847, ...",[837],"[0.22234336592604098, 0.16111913382752438, 0.1...","[-0.000896426434791848, -0.0030655746161733586..."


### Building a test dataset

To augment the provided training set, we'll quickly prepare a second dataset for testing and validation. We use the same label functions:

In [9]:
from openff.nagl.label.labels import LabelCharges

# Choose the molecules to put in this dataset
# Note that these molecules aren't in the training dataset!
test_smiles = [
    "CCCCCCC",
    "CC(C)C(C)C",
    "CC(C)(C)C",
]

test_dataset = LabelledDataset.from_smiles(
    "my_first_test_dataset",  # path to save to
    test_smiles,
    mapped=False,
    overwrite_existing=True,
)

test_dataset.apply_labellers(labellers)
test_dataset.to_pandas()

Unnamed: 0,mapped_smiles,conformers,n_conformers,target-am1-charges,target-am1bcc-charges,target-am1bcc-dipoles,esp_lengths,grid_inverse_distances,esps
0,[H:8][C:1]([H:9])([H:10])[C:2]([H:11])([H:12])...,"[1.3991600275039673, 1.056309461593628, -1.841...",2,"[-0.21044000043817188, -0.15767000421233798, -...","[-0.09253999882418176, -0.07907000103074571, -...","[-0.0004773724634927201, -0.001530751792649470...","[999, 976]","[0.22234336592604098, 0.21183263152347134, 0.1...","[-0.00251609579612631, -0.001154553480499424, ..."
1,[H:7][C:1]([H:8])([H:9])[C:2]([H:10])([C:3]([H...,"[0.5108481645584106, -0.7951530814170837, 0.55...",1,"[-0.20714999884366989, -0.10120000094175338, -...","[-0.08924999833106995, -0.061900001019239426, ...","[-0.0031004916088699375, -0.003375164159762522...",[820],"[0.22234336592604098, 0.17268756186264236, 0.1...","[-0.002090256191838929, -0.002086751713862172,..."
2,[H:6][C:1]([H:7])([H:8])[C:2]([C:3]([H:9])([H:...,"[-0.613831102848053, -0.1246185302734375, 0.89...",1,"[-0.20284000192494953, -0.059600000872331506, ...","[-0.08494000114938792, -0.05959999911925372, -...","[7.753205436847321e-06, 3.927593248914327e-06,...",[748],"[0.22234336592604098, 0.16031090135271214, 0.1...","[-0.0008350871636610829, -0.001165141365137959..."
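The `conformers` and `esps` columns are stored as flat lists, with `n_conformers` and `esp_lengths` recording how to split them back up (note the two conformers and two grid sizes for the first molecule above). The sketch below shows one way this could work with made-up numbers. It assumes the values are concatenated conformer by conformer in row-major order, which is an assumption about the layout rather than something confirmed here.

```python
import numpy as np

# Made-up example: 3 atoms, 2 conformers, stored as one flat list of coordinates
n_atoms = 3
n_conformers = 2
flat_conformers = np.arange(n_conformers * n_atoms * 3, dtype=float)

# Assumed layout: all coordinates of conformer 0, then all of conformer 1
conformers = flat_conformers.reshape(n_conformers, n_atoms, 3)

# Each conformer has its own ESP grid, so the grids can differ in length
esp_lengths = [4, 2]
flat_esps = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])

# Split the flat ESP list at the cumulative grid lengths
per_conformer_esps = np.split(flat_esps, np.cumsum(esp_lengths)[:-1])

print(conformers.shape)
print([esps.tolist() for esps in per_conformer_esps])
```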


### Curating our data module

Now we assemble our datasets into a `DataConfig`. For each `DatasetConfig`, we need to specify the targets we are choosing to fit. A `Target` is what we use to construct the objective function and calculate loss. Below we:
- fit the `predicted-am1-charges` property directly to the labelled `target-am1-charges`
- fit the `predicted-am1bcc-charges` property to a combined objective of:
  - charge RMSE
  - dipole moments
  - ESP targets
 
For the physical properties, we need additional information (the `conformers`, `n_conformers`, ... columns) to calculate the dipole moments and ESPs for comparison. Each of these has been annotated. The `target_label` always refers to the column in the input dataset that is the property we are comparing.

In [10]:
from openff.nagl.config.data import DatasetConfig, DataConfig
from openff.nagl.training.metrics import RMSEMetric
from openff.nagl.training.loss import ReadoutTarget, MultipleDipoleTarget, ESPTarget


am1_charge_rmse_target = ReadoutTarget(
    metric=RMSEMetric(),  # use RMSE to calculate loss
    target_label="target-am1-charges", # column to use from data as reference target
    prediction_label="predicted-am1-charges", # readout value to compare to target
    denominator=1.0, # denominator to normalise loss -- important for multi-target objectives
    weight=1.0, # how much to weight the loss -- important for multi-target objectives
)

am1bcc_charge_rmse_target = ReadoutTarget(
    metric=RMSEMetric(),  # use RMSE to calculate loss
    target_label="target-am1bcc-charges", # column to use from data as reference target
    prediction_label="predicted-am1bcc-charges", # readout value to compare to target
    denominator=0.001, # denominator to normalise loss -- important for multi-target objectives
    weight=1.0, # how much to weight the loss -- important for multi-target objectives
)

am1bcc_dipole_target = MultipleDipoleTarget(
    metric=RMSEMetric(),
    target_label="target-am1bcc-dipoles", # column to use from input data as reference target
    charge_label="predicted-am1bcc-charges", # readout charge value to calculate dipoles with
    conformation_column="conformers", # input data to use for calculating dipoles
    n_conformation_column="n_conformers", # input data to use for calculating dipoles
    denominator=0.01,
    weight=1.0
)

am1bcc_esp_target = ESPTarget(
    metric=RMSEMetric(),
    target_label="esps", # column to use from input data as reference target
    charge_label="predicted-am1bcc-charges", # readout charge value to calculate ESPs with
    inverse_distance_matrix_column="grid_inverse_distances", # input data to use to calculate ESPs
    esp_length_column="esp_lengths", # input data to use to calculate ESPs
    n_esp_column="n_conformers", # input data to use to calculate ESPs
    denominator=0.001,
    weight=1.0
)
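To see why the dipole and ESP targets need conformers and inverse distances, here is a minimal NumPy sketch of how both properties follow from a set of point charges. The numbers are made up, and unit conversions and physical constants are omitted, so this only illustrates the structure of the calculation rather than NAGL's implementation.

```python
import numpy as np

# Made-up partial charges and coordinates for a 3-atom molecule
charges = np.array([-0.4, 0.2, 0.2])
coordinates = np.array([
    [0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
])

# Dipole moment: sum over atoms of charge times position
dipole = charges @ coordinates

# ESP at each grid point: sum over atoms of charge / distance
grid = np.array([
    [2.0, 0.0, 0.0],
    [0.0, 0.0, 3.0],
])
distances = np.linalg.norm(grid[:, None, :] - coordinates[None, :, :], axis=-1)
inverse_distances = 1.0 / distances  # shape (n_grid_points, n_atoms)
esp = inverse_distances @ charges

print(dipole)
print(esp)
```

Both properties are linear in the charges, so once the geometry-dependent parts (coordinates and inverse distances) are precomputed as labels, the predicted charges are the only quantity that changes during training.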

Now we combine each of these targets into each `DatasetConfig`.

In [11]:
targets = [
    am1_charge_rmse_target,
    am1bcc_charge_rmse_target,
    am1bcc_dipole_target,
    am1bcc_esp_target,
]
    

training_dataset_config = DatasetConfig(
    sources=["training_data"],
    targets=targets,
    batch_size=1000,
)

test_dataset_config = validation_dataset_config = DatasetConfig(
    sources=["my_first_test_dataset"],
    targets=targets,
    batch_size=1000,
)

data_config = DataConfig(
    training=training_dataset_config,
    validation=validation_dataset_config,
    test=test_dataset_config
)

## Train the model

We've prepared our model architecture and our training, validation and test data; now we just need to fit the model! To do this, we need to specify optimization settings with an `OptimizerConfig`, and then put everything together in a `TrainingConfig`.

In [12]:
from openff.nagl.config.optimizer import OptimizerConfig
from openff.nagl.config.training import TrainingConfig

optimizer_config = OptimizerConfig(
    optimizer="Adam",
    learning_rate=0.001,
)

training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)

In [13]:
from openff.nagl.training.training import TrainingGNNModel, DGLMoleculeDataModule

training_model = TrainingGNNModel(training_config)
data_module = DGLMoleculeDataModule(training_config)

To properly fit the model, we use the [`Trainer`] class from PyTorch Lightning. This allows us to configure how data and progress are stored and reported using callbacks. The [`fit()`] method trains and validates against the data module we provide it: 

[`Trainer`]: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html
[`fit()`]: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.fit

In [14]:
from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=200)

trainer.progress_bar_callback.disable()
trainer.checkpoint_callback.monitor = "val/loss"

trainer.fit(
    training_model,
    datamodule=data_module
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|███████████████████████| 10/10 [00:00<00:00, 116.92it/s]
Featurizing dataset: 1it [00:00,  7.84it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|██████████████████████████| 3/3 [00:00<00:00, 95.75it/s]
Featurizing dataset: 1it [00:00, 17.53it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|██████████████████████████| 3/3 [00:00<00:00, 89.77it/s]
Featurizing dataset: 1it [00:00, 16.09it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|███████████████████████| 10/10 [00:00<00:00, 132.26it/s]
Featurizing dataset: 1it [00:00,  8.60it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|██████████████████████████| 3/3 [00:00<00:00, 96.33it/s]
Featurizing dataset: 1it [00:00, 17.42it/s]
Featurizing d

## Results!

We can use the `Trainer` object's [`test()`] method to evaluate the model against our test data:

[`test()`]: https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html#pytorch_lightning.trainer.trainer.Trainer.test

In [15]:
trainer.test(training_model, data_module)

Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|███████████████████████| 10/10 [00:00<00:00, 137.43it/s]
Featurizing dataset: 1it [00:00,  8.79it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|█████████████████████████| 3/3 [00:00<00:00, 108.08it/s]
Featurizing dataset: 1it [00:00, 18.92it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|██████████████████████████| 3/3 [00:00<00:00, 88.46it/s]
Featurizing dataset: 1it [00:00, 16.13it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|███████████████████████| 10/10 [00:00<00:00, 133.77it/s]
Featurizing dataset: 1it [00:00,  8.68it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|██████████████████████████| 3/3 [00:00<00:00, 98.38it/s]
Featurizing dataset: 1it [00:00, 17.62it/s]
Featurizing dataset: 0it [00:00, ?it/s]
Featurizing batch: 100%|█████████████████████████| 3/3 [00:00<00:00, 103.55it/s]
Featurizing dataset: 1it [00:00, 18.

[{'test/target-am1-charges/readout/rmse/1.0/1.0': 0.008478229865431786,
  'test/target-am1bcc-charges/readout/rmse/1.0/0.001': 3.4283690452575684,
  'test/target-am1bcc-dipoles/multi_dipole/rmse/1.0/0.01': 0.2450200766324997,
  'test/esps/esp/rmse/1.0/0.001': 0.10448392480611801,
  'test/loss': 3.786351203918457}]
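The reported numbers can be sanity-checked by hand. The short sketch below copies the values from the output above and confirms that the four per-target terms add up to `test/loss`. It also assumes that the last two numbers in each key are the target's weight and denominator, and that each reported term is `weight * RMSE / denominator`; that reading is inferred from the configuration above rather than confirmed.

```python
# Values copied from the trainer.test() output above
reported = {
    "test/target-am1-charges/readout/rmse/1.0/1.0": 0.008478229865431786,
    "test/target-am1bcc-charges/readout/rmse/1.0/0.001": 3.4283690452575684,
    "test/target-am1bcc-dipoles/multi_dipole/rmse/1.0/0.01": 0.2450200766324997,
    "test/esps/esp/rmse/1.0/0.001": 0.10448392480611801,
    "test/loss": 3.786351203918457,
}

# The total loss should be the sum of the individual target terms
component_sum = sum(value for key, value in reported.items() if key != "test/loss")
print(component_sum, reported["test/loss"])

# Assumption: each term is weight * RMSE / denominator, so undoing the
# scaling recovers the RMSE in the original units (elementary charge here)
weight, denominator = 1.0, 0.001
raw_am1bcc_charge_rmse = (
    reported["test/target-am1bcc-charges/readout/rmse/1.0/0.001"] * denominator / weight
)
print(raw_am1bcc_charge_rmse)
```

Under that assumption, the large AM1BCC charge term does not mean the charges are poor: dividing by a small denominator simply makes that target count for more in the combined objective.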

We can isolate the model itself from all the training requirements:

In [16]:
model = training_model.model

Octane isn't in any of our data, so the model hasn't seen it yet! We can predict its partial charges with the [`compute_property()`] method:

[`compute_property()`]: https://docs.openforcefield.org/projects/nagl/en/stable/api/generated/openff.nagl.GNNModel.html#openff.nagl.GNNModel.compute_property

In [18]:
octane = Molecule.from_smiles("CCCCCCCC")

am1bcc_charges = model.compute_property(octane, readout_name="predicted-am1bcc-charges")

And we can compare that to the AM1BCC partial charges produced by the OpenFF Toolkit:

In [19]:
octane.assign_partial_charges("am1bcc")
octane.partial_charges

0,1
Magnitude,[-0.09216000225681525 -0.07999000039238197 -0.079319999481623  -0.07835999962228996 -0.07835999962228996 -0.079319999481623  -0.07999000039238197 -0.09207999792236549 0.03245000082712907  0.03245000082712907 0.03245000082712907 0.03792999971371431  0.03792999971371431 0.038880000893886275 0.038880000893886275  0.03940999794464845 0.03940999794464845 0.03940999794464845  0.03940999794464845 0.038880000893886275 0.038880000893886275  0.03792999971371431 0.03792999971371431 0.03245000082712907  0.03245000082712907 0.03245000082712907]
Units,elementary_charge


In [20]:
prediction = am1bcc_charges * unit.elementary_charge
np.abs(prediction - octane.partial_charges)

0,1
Magnitude,[0.0013413569675042036 0.0009164056525780562 0.00014290127616661819  0.0009294814215256575 0.0009295112238480452 0.00014290127616661819  0.0009163907514168623 0.0014213762031151655 0.0007257954432414102  0.0007257656409190225 0.0007257954432414102 0.0006569744302676248  0.0006570042325900124 0.00023466807145338525 0.00023469787377577295  0.00029532897930879126 0.00029529917698640357 0.00029532897930879126  0.00029529917698640357 0.00023466807145338525 0.00023469787377577295  0.0006569744302676248 0.0006570042325900124 0.0007257656409190225  0.0007257954432414102 0.0007257656409190225]
Units,elementary_charge


All within 0.002 elementary charge units of true AM1BCC charges! Not too bad!

Similarly, looking at AM1 charges:

In [21]:
am1_charges = model.compute_property(octane, readout_name="predicted-am1-charges")
am1_charges

array([-0.19512036, -0.14707863, -0.14554948, -0.1456733 , -0.14567327,
       -0.14554948, -0.1470786 , -0.19512033,  0.06604499,  0.06604499,
        0.06604499,  0.07206178,  0.07206178,  0.0727908 ,  0.0727908 ,
        0.0727908 ,  0.0727908 ,  0.0727908 ,  0.0727908 ,  0.0727908 ,
        0.0727908 ,  0.07206178,  0.07206178,  0.06604499,  0.06604499,
        0.06604499], dtype=float32)

In [22]:
octane.assign_partial_charges("am1-mulliken")
octane.partial_charges

0,1
Magnitude,[-0.21005999984649512 -0.1585900032749543 -0.15792000236419532  -0.15695999505428168 -0.15695999505428168 -0.15792000236419532  -0.1585900032749543 -0.20997999551204535 0.07175000069233087  0.07175000069233087 0.07175000069233087 0.0772299995789161  0.0772299995789161 0.07818000075908807 0.07818000075908807  0.07870999780985025 0.07870999780985025 0.07870999780985025  0.07870999780985025 0.07818000075908807 0.07818000075908807  0.0772299995789161 0.0772299995789161 0.07175000069233087  0.07175000069233087 0.07175000069233087]
Units,elementary_charge


In [23]:
am1_prediction = am1_charges * unit.elementary_charge
np.abs(am1_prediction - octane.partial_charges)

0,1
Magnitude,[0.014939635418928593 0.011511369966543644 0.012370526217497319  0.01128669025806281 0.011286720060385197 0.012370526217497319  0.011511399768866032 0.014859660886801213 0.00570501444431451  0.00570501444431451 0.00570501444431451 0.005168222464047939  0.005168222464047939 0.005389199233972103 0.005389199233972103  0.005919196284734279 0.005919196284734279 0.005919196284734279  0.005919196284734279 0.005389199233972103 0.005389199233972103  0.005168222464047939 0.005168222464047939 0.00570501444431451  0.00570501444431451 0.00570501444431451]
Units,elementary_charge


This is less accurate: the errors reach about 0.015 elementary charge, roughly ten times the largest AM1BCC error above. One possible reason is that the AM1 charges were fit only to the charges directly, with none of the physical properties.
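For a more compact comparison than reading the per-atom differences, we can summarise the errors with a maximum and an RMSE. The sketch below copies the predicted AM1 charges from the output above and uses the reference charges rounded to five decimal places, so the numbers are only as precise as that rounding. It also checks that the predicted charges sum to zero, which is what we would expect from the charge equilibration post-processing on a neutral molecule.

```python
import numpy as np

# Predicted AM1 charges for octane, copied from the output above
predicted = np.array([
    -0.19512036, -0.14707863, -0.14554948, -0.1456733, -0.14567327,
    -0.14554948, -0.1470786, -0.19512033, 0.06604499, 0.06604499,
    0.06604499, 0.07206178, 0.07206178, 0.0727908, 0.0727908,
    0.0727908, 0.0727908, 0.0727908, 0.0727908, 0.0727908,
    0.0727908, 0.07206178, 0.07206178, 0.06604499, 0.06604499,
    0.06604499,
])

# Reference AM1 charges from the toolkit, rounded to 5 decimal places
reference = np.array(
    [-0.21006, -0.15859, -0.15792, -0.15696, -0.15696, -0.15792, -0.15859, -0.20998]
    + [0.07175] * 3 + [0.07723] * 2 + [0.07818] * 2 + [0.07871] * 4
    + [0.07818] * 2 + [0.07723] * 2 + [0.07175] * 3
)

errors = np.abs(predicted - reference)
max_error = errors.max()
rmse = np.sqrt(np.mean(errors ** 2))
net_charge = predicted.sum()

print(f"max error: {max_error:.5f} e")
print(f"RMSE:      {rmse:.5f} e")
print(f"net predicted charge: {net_charge:.1e} e")
```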

## Saving and loading our model

We can save the final model with the `model.save()` method. This'll let us store it for later.

In [24]:
model.save("trained_alkane_model.pt")

When we want it again, we can use the `GNNModel.load()` method:

In [25]:
model_from_disk = GNNModel.load("trained_alkane_model.pt")
model_from_disk.compute_property(octane, readout_name="predicted-am1bcc-charges")

array([-0.09350136, -0.08090641, -0.0794629 , -0.07928948, -0.07928951,
       -0.0794629 , -0.08090639, -0.09350137,  0.0331758 ,  0.03317577,
        0.0331758 ,  0.03858697,  0.038587  ,  0.03911467,  0.0391147 ,
        0.03911467,  0.0391147 ,  0.03911467,  0.0391147 ,  0.03911467,
        0.0391147 ,  0.03858697,  0.038587  ,  0.03317577,  0.0331758 ,
        0.03317577], dtype=float32)