# Train a GNN directly to an electric field

To execute this example fully, the following packages are required.

* openff-nagl
* openff-recharge
* openff-qcsubmit
* psi4

However, if you only want to follow the training part without generating the training datasets yourself, you only need `openff-nagl` installed and can load the training and validation data from the provided `.parquet` files. The commands for doing so are provided, commented out, at the end of the "Generate and format training data" section.

In [1]:
import tqdm

from qcportal import PortalClient
from openff.units import unit

from openff.toolkit import Molecule
from openff.qcsubmit.results import BasicResultCollection
from openff.recharge.esp.storage import MoleculeESPRecord
from openff.recharge.esp.qcresults import from_qcportal_results
from openff.recharge.grids import MSKGridSettings
from openff.recharge.utilities.geometry import compute_vector_field

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

import torch



## Generate and format training data


### Downloading from QCArchive
First, we create the training data. For the purposes of this example, we download a small dataset.

In [2]:
qc_client = PortalClient("https://api.qcarchive.molssi.org:443", cache_dir=".")

# download dataset from QCPortal
br_esps_collection = BasicResultCollection.from_server(
    client=qc_client,
    datasets="OpenFF multi-Br ESP Fragment Conformers v1.1",
    spec_name="HF/6-31G*",
)

records_and_molecules = br_esps_collection.to_records()

### Converting to MoleculeESPRecords

Now we convert to OpenFF Recharge records and compute the ESPs and electric fields of each molecule.

In [3]:
# Create OpenFF Recharge MoleculeESPRecords
grid_settings = MSKGridSettings()

# this can take a while, so we only use the first 50 records here;
# change the slice (e.g. records_and_molecules[:10]) to use fewer
molecule_esp_records = [
    from_qcportal_results(
        qc_result=qcrecord,
        qc_molecule=qcrecord.molecule,
        qc_keyword_set=qcrecord.specification.keywords,
        grid_settings=grid_settings,
        compute_field=True
    )
    for qcrecord, _ in tqdm.tqdm(records_and_molecules[:50])
]

100%|██████████| 50/50 [06:52<00:00,  8.26s/it]


### Convert to PyArrow dataset

NAGL reads in and trains to data from [PyArrow tables](https://arrow.apache.org/docs/python/getstarted.html#creating-arrays-and-tables). Below we convert each electric field into the form expected by a basic `GeneralLinearFitTarget`, which fits the equation Ax = b. Here A is the vector field computed from the conformer and grid coordinates, x is the vector of atomic charges predicted by the GNN, and b is the reference electric field. To avoid carrying too many data points around, we postprocess by first flattening A from 3 dimensions to 2 (and b from 2 dimensions to 1), then multiplying both sides by the transpose of A.

$$\mathbf{A}\vec{x} = \vec{b}$$
$$\mathbf{A^{T}}\mathbf{A}\vec{x} = \mathbf{A^{T}}\vec{b}$$
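As a rough illustration of why this reduction is safe, the sketch below builds a toy vector field with plain NumPy, flattens it as described, and checks that the reduced system still recovers the charges. The coordinates, the charges, and the explicit formula used for the vector field (displacement divided by cubed distance, in atomic units) are assumptions made for illustration; in the notebook the real values come from `compute_vector_field` and the QC records.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy system: M = 3 atoms and N = 20 grid points, in bohr.
atoms = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.0, 2.0, 0.0]])  # M x 3
directions = rng.normal(size=(20, 3))
grid = 6.0 * directions / np.linalg.norm(directions, axis=1, keepdims=True)  # N x 3

# Assumed form of the vector field: (r_grid - r_atom) / |r_grid - r_atom|^3
displacement = grid[:, None, :] - atoms[None, :, :]  # N x M x 3
distance = np.linalg.norm(displacement, axis=-1, keepdims=True)  # N x M x 1
vector_field = np.transpose(displacement / distance**3, (0, 2, 1))  # N x 3 x M

# Made-up charges generate the "reference" electric field.
charges = np.array([0.3, -0.1, -0.2])
electric_field = vector_field @ charges  # N x 3

# Flatten: N x 3 x M -> 3N x M, and N x 3 -> 3N
A = vector_field.reshape(-1, atoms.shape[0])
b = electric_field.reshape(-1)

# Reduce to an M x M system by multiplying both sides by A^T
ata = A.T @ A  # M x M
atb = A.T @ b  # M

recovered = np.linalg.solve(ata, atb)
print(A.shape, ata.shape)
print(recovered)
```

The stored data shrinks from `3N * M` floats to `M * M` floats per molecule, while the least-squares solution for the charges is unchanged.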

In [4]:
pyarrow_entries = []
for molecule_esp_record in tqdm.tqdm(molecule_esp_records):
    electric_field = molecule_esp_record.electric_field # in atomic units
    grid = molecule_esp_record.grid_coordinates * unit.angstrom
    conformer = molecule_esp_record.conformer * unit.angstrom

    vector_field = compute_vector_field(
        conformer.m_as(unit.bohr),  # shape: M x 3
        grid.m_as(unit.bohr),  # shape: N x 3
    ) # N x 3 x M

    # postprocess so we're not carrying around millions of floats
    # first flatten N x 3 x M -> 3N x M (and the field N x 3 -> 3N)
    vector_field_2d = np.concatenate(vector_field, axis=0)
    electric_field_1d = np.concatenate(electric_field, axis=0)

    # now multiply by vector_field_2d's transpose
    new_precursor_matrix = vector_field_2d.T @ vector_field_2d
    new_field_vector = vector_field_2d.T @ electric_field_1d

    n_atoms = conformer.shape[0]
    assert new_precursor_matrix.shape == (n_atoms, n_atoms)
    assert new_field_vector.shape == (n_atoms,)

    # create entry; all three columns are required for this target,
    # and mapped_smiles is required for every target
    entry = {
        "mapped_smiles": molecule_esp_record.tagged_smiles,
        "precursor_matrix": new_precursor_matrix.flatten().tolist(),
        "prediction_vector": new_field_vector.tolist()
    }
    pyarrow_entries.append(entry)


# arbitrarily split into training and validation datasets
training_pyarrow_entries = pyarrow_entries[:-10]
validation_pyarrow_entries = pyarrow_entries[-10:]

training_table = pa.Table.from_pylist(training_pyarrow_entries)
validation_table = pa.Table.from_pylist(validation_pyarrow_entries)
training_table

100%|██████████| 50/50 [00:00<00:00, 253.88it/s]

pyarrow.Table
mapped_smiles: string
precursor_matrix: list<item: double>
  child 0, item: double
prediction_vector: list<item: double>
  child 0, item: double
----
mapped_smiles: [["[H:1][N:2]1[C:3]([Br:4])=[N:5][c:6]2[n:7][c:8]([Br:9])[c:10]([Br:11])[n:12][c:13]21","[H:1][N:2]([H:3])[C:4]1([C:5](=[O:6])[O-:7])[C:8]([H:9])([H:10])[C:11]2([H:12])[C:13]([Br:14])([Br:15])[C:16]2([H:17])[C:18]1([H:19])[H:20]","[H:1][C:2]1([H:3])[C:4]2([H:5])[C:6]([Br:7])([Br:8])[C:9]2([H:10])[C:11]([H:12])([H:13])[C:14]1([C:15](=[O:16])[O-:17])[N+:18]([H:19])([H:20])[H:21]","[H:1][N:2]1[C:3]([H:4])([H:5])[C:6]2([H:7])[C:8]([Br:9])([Br:10])[C:11]2([H:12])[C:13]1([H:14])[C:15](=[O:16])[O-:17]","[H:1][C:2]12[C:3]([Br:4])([Br:5])[C:6]1([H:7])[C:8]([H:9])([C:10](=[O:11])[O-:12])[N+:13]([H:14])([H:15])[C:16]2([H:17])[H:18]",...,"[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[H:19]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[B

In [5]:
pq.write_table(training_table, "training_dataset_table.parquet")
pq.write_table(validation_table, "validation_dataset_table.parquet")

# # to read back in -- note, the files saved here give the full dataset, not the 50 record subset
# training_table = pq.read_table("training_dataset_table.parquet")
# validation_table = pq.read_table("validation_dataset_table.parquet")
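Because `precursor_matrix` is stored as a flat list, anyone inspecting a table by hand needs to reshape it back into a square matrix. The sketch below shows one way to do that with a made-up entry, taking the number of atoms from the length of `prediction_vector`. The entry contents are hypothetical, and how NAGL itself restores the shape internally is not shown in this notebook.

```python
import numpy as np

# Hypothetical entry mimicking the columns written above (4 atoms).
n_atoms = 4
original_matrix = np.arange(n_atoms * n_atoms, dtype=float).reshape(n_atoms, n_atoms)
entry = {
    "precursor_matrix": original_matrix.flatten().tolist(),
    "prediction_vector": [0.1, -0.2, 0.3, -0.2],
}

# The vector has one value per atom, so its length gives the matrix size.
n = len(entry["prediction_vector"])
restored_matrix = np.array(entry["precursor_matrix"]).reshape(n, n)

print(restored_matrix.shape)
```

`ndarray.flatten()` and `reshape` both default to row-major order, so the round trip returns the original matrix.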

## Set up for training a GNN

In [6]:
from openff.nagl.config import (
    TrainingConfig,
    OptimizerConfig,
    ModelConfig,
    DataConfig
)
from openff.nagl.config.model import (
    ConvolutionModule, ReadoutModule,
    ConvolutionLayer, ForwardLayer,
)
from openff.nagl.config.data import DatasetConfig
from openff.nagl.training.training import TrainingGNNModel
from openff.nagl.features.atoms import (
    AtomicElement,
    AtomConnectivity,
    AtomInRingOfSize,
    AtomAverageFormalCharge,
)

from openff.nagl.training.loss import GeneralLinearFitTarget

### Defining the training config

#### Defining a ModelConfig

First we define a ModelConfig.
This can be done in Python, but in practice it is probably easier to define the model in a YAML file and load it with `ModelConfig.from_yaml`.

In [7]:
atom_features = [
    AtomicElement(categories=["H", "C", "N", "O", "F", "Br", "S", "P", "I"]),
    AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),
    AtomInRingOfSize(ring_size=3),
    AtomInRingOfSize(ring_size=4),
    AtomInRingOfSize(ring_size=5),
    AtomInRingOfSize(ring_size=6),
    AtomAverageFormalCharge(),
]

# define our convolution module
convolution_module = ConvolutionModule(
    architecture="SAGEConv",
    # construct 6 layers with dropout 0 (default),
    # hidden feature size 512, and ReLU activation function
    # these layers can also be individually specified,
    # but we just duplicate the layer 6 times for identical layers
    layers=[
        ConvolutionLayer(
            hidden_feature_size=512,
            activation_function="ReLU",
            aggregator_type="mean"
        )
    ] * 6,
)

# define our readout module/s
# multiple are allowed but let's focus on charges
readout_modules = {
    # key is the name of output property, any naming is allowed
    "charges": ReadoutModule(
        pooling="atoms",
        postprocess="compute_partial_charges",
        # 2 layers
        layers=[
            ForwardLayer(
                hidden_feature_size=512,
                activation_function="ReLU",
            )
        ] * 2,
    )
}

# bring it all together
model_config = ModelConfig(
    version="0.1",
    atom_features=atom_features,
    convolution=convolution_module,
    readouts=readout_modules,
)
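To make the configuration above more concrete, here is a minimal NumPy sketch of the input feature width and of a single GraphSAGE-style layer with mean aggregation. The per-feature widths are assumptions (one-hot per category, one value per ring flag and per formal charge), although they do add up to the `in_features=20` printed for the first layer of the model. The toy graph, the random weights, and the hidden size of 8 are purely illustrative, and this is not NAGL's implementation.

```python
import numpy as np

# Assumed widths: one-hot encodings for categorical features, 1 value otherwise.
feature_widths = {
    "AtomicElement": 9,           # H, C, N, O, F, Br, S, P, I
    "AtomConnectivity": 6,        # 1..6 bonded neighbours
    "AtomInRingOfSize": 4,        # one flag each for ring sizes 3, 4, 5, 6
    "AtomAverageFormalCharge": 1,
}
n_features = sum(feature_widths.values())

rng = np.random.default_rng(0)

# Toy graph: a chain of three atoms, 0-1-2.
adjacency = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
h = rng.normal(size=(3, n_features))  # stand-in atom features

hidden_size = 8  # 512 in the notebook's config
w_self = rng.normal(size=(n_features, hidden_size))
w_neigh = rng.normal(size=(n_features, hidden_size))
bias = np.zeros(hidden_size)

# Mean aggregation: average the features of each atom's neighbours.
degree = adjacency.sum(axis=1, keepdims=True)
neighbour_mean = (adjacency @ h) / degree

# Combine self and neighbour terms, then apply ReLU.
h_next = np.maximum(0.0, h @ w_self + neighbour_mean @ w_neigh + bias)

print(n_features, h_next.shape)
```

Stacking six such layers lets information travel up to six bonds away from each atom before the readout module turns the atom embeddings into charges.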

#### Defining a DataConfig

We can then define our dataset configs. Here we also have to specify our training targets.

In [8]:
target = GeneralLinearFitTarget(
    # what we're using to evaluate loss
    target_label="prediction_vector",
    # the output of the GNN we use to evaluate loss
    prediction_label="charges",
    # the column in the table that contains the precursor matrix
    design_matrix_column="precursor_matrix",
    # how we want to evaluate loss, e.g. RMSE, MSE, ...
    metric="rmse",
    # how much to weight this target
    # helps with scaling in multi-target optimizations
    weight=1,
    # value the loss is divided by; 1 leaves it unscaled
    denominator=1,
)

training_to_electric_field = DatasetConfig(
    sources=["training_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)
validating_to_electric_field = DatasetConfig(
    sources=["validation_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)

# bringing it together
data_config = DataConfig(
    training=training_to_electric_field,
    validation=validating_to_electric_field
)
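Conceptually, this target multiplies the design matrix by the predicted charges and compares the result with the target vector. The sketch below illustrates that idea with NumPy on a hand-made 2 x 2 system. The function name `general_linear_fit_rmse` is hypothetical, and NAGL's own implementation may differ in details such as batching, weighting, and the denominator.

```python
import numpy as np


def general_linear_fit_rmse(design_matrix_flat, predicted, target):
    """RMSE between A @ predicted and target (illustrative helper only)."""
    predicted = np.asarray(predicted, dtype=float)
    target = np.asarray(target, dtype=float)
    # The design matrix is stored flat, so restore its 2D shape first.
    design_matrix = np.asarray(design_matrix_flat, dtype=float).reshape(
        -1, predicted.shape[0]
    )
    residual = design_matrix @ predicted - target
    return float(np.sqrt(np.mean(residual**2)))


# Hand-made system: A = [[2, 0], [0, 1]], true x = [0.5, -0.5], so b = [1, -0.5]
design_matrix_flat = [2.0, 0.0, 0.0, 1.0]
target_vector = [1.0, -0.5]

loss_at_truth = general_linear_fit_rmse(design_matrix_flat, [0.5, -0.5], target_vector)
loss_at_zero = general_linear_fit_rmse(design_matrix_flat, [0.0, 0.0], target_vector)

print(loss_at_truth, loss_at_zero)
```

The loss is zero only when the predicted charges reproduce the target vector, which is what drives the GNN toward charges that reproduce the electric field.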

#### Defining an OptimizerConfig

The optimizer config is relatively simple; the only moving part here currently is the learning rate.

In [9]:
optimizer_config = OptimizerConfig(optimizer="Adam", learning_rate=0.001)
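To give a feel for what the learning rate controls, here is a minimal pure-Python sketch of one Adam update for a single parameter. The `adam_step` helper is hypothetical, the `beta1`, `beta2`, and `eps` values are the commonly used Adam defaults, and it is an assumption that NAGL's optimizer behaves like standard Adam.

```python
import math


def adam_step(param, grad, m, v, step,
              learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to a scalar parameter (illustrative only)."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    # Bias correction compensates for the zero initialisation of m and v.
    m_hat = m / (1 - beta1**step)
    v_hat = v / (1 - beta2**step)
    param = param - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


start = 1.0
updated, m, v = adam_step(param=start, grad=5.0, m=0.0, v=0.0, step=1)
step_size = start - updated
print(step_size)
```

On the first step the update is almost exactly `learning_rate` in size regardless of the gradient's magnitude, so `learning_rate=0.001` sets the scale of each parameter change.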

#### Creating a TrainingConfig

In [10]:
training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)

### Creating a TrainingGNNModel

Now we can create a `TrainingGNNModel`, which allows easy training of a `GNNModel`. The `GNNModel` can be accessed through `TrainingGNNModel.model`.

In [11]:
training_model = TrainingGNNModel(training_config)
training_model

TrainingGNNModel(
  (model): GNNModel(
    (convolution_module): ConvolutionModule(
      (gcn_layers): SAGEConvStack(
        (0): SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=20, out_features=512, bias=False)
          (fc_self): Linear(in_features=20, out_features=512, bias=True)
        )
        (1-5): 5 x SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=512, out_features=512, bias=False)
          (fc_self): Linear(in_features=512, out_features=512, bias=True)
        )
      )
    )
    (readout_modules): ModuleDict(
      (charges): ReadoutModule(
        (pooling_layer): PoolAtomFeatures()
        (readout_layers): SequentialLayers(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, 

Before training, we can gauge the model by comparing its predicted charges with AM1-BCC charges. As expected for an untrained model, the agreement is poor.

In [12]:
test_molecule = Molecule.from_smiles("CCCBr")
test_molecule.assign_partial_charges("am1bcc")
reference_charges = test_molecule.partial_charges.m

# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    nagl_charges_1 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["charges"]

# switch back to training mode
training_model.model.train()

# compare charges
differences = reference_charges - nagl_charges_1
differences

array([-0.40199522, -0.57868662, -0.93526483,  0.20943655,  0.58222105,
        0.58222105,  0.58222105, -0.04509645, -0.04509645,  0.02502038,
        0.02502038])

### Training the model

We use PyTorch Lightning to train.

In [13]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

In [14]:
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[TQDMProgressBar()], # add progress bar
    accelerator="cpu"
)

GPU available: False, used: False


TPU available: False, using: 0 TPU cores



In [15]:
datamodule = training_model.create_data_module(verbose=False)

In [16]:
trainer.fit(
    training_model,
    datamodule=datamodule,
)

/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:317: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


`Trainer.fit` stopped: `max_epochs=100` reached.


We can now check the charges again. They should have improved, and would likely improve further with a larger dataset. Note that results may vary with the small 50-record dataset this notebook uses for speed.

In [17]:
# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    nagl_charges_2 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["charges"]

differences_after_training = reference_charges - nagl_charges_2
differences_after_training

array([-0.15572474, -0.09356335, -0.04113645, -0.04971532,  0.02554091,
        0.02554091,  0.02554091,  0.03322632,  0.03322632,  0.09853192,
        0.09853192])
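One simple way to summarise the two comparisons is a single root-mean-square error for each. The sketch below copies the difference arrays printed in this notebook and computes their RMSE with NumPy. The `rmse` helper is defined here for illustration, and the numbers will differ on a fresh run.

```python
import numpy as np

# Differences (AM1-BCC minus NAGL) copied from the outputs in this notebook.
differences = np.array([
    -0.40199522, -0.57868662, -0.93526483, 0.20943655, 0.58222105,
    0.58222105, 0.58222105, -0.04509645, -0.04509645, 0.02502038,
    0.02502038,
])
differences_after_training = np.array([
    -0.15572474, -0.09356335, -0.04113645, -0.04971532, 0.02554091,
    0.02554091, 0.02554091, 0.03322632, 0.03322632, 0.09853192,
    0.09853192,
])


def rmse(values):
    """Root-mean-square of an array of differences."""
    return float(np.sqrt(np.mean(np.square(values))))


rmse_before = rmse(differences)
rmse_after = rmse(differences_after_training)

print(f"RMSE before training: {rmse_before:.3f} e")
print(f"RMSE after training:  {rmse_after:.3f} e")
```

For this run the error after training is several times smaller than before.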