# Train a GNN directly to an electric field

To execute this example fully, the following packages are required.

* openff-nagl
* openff-recharge
* openff-qcsubmit
* psi4

However, if you wish to just follow along the training part without first creating the training datasets yourself, you can get away with just `openff-nagl` installed and simply load the training/validation data from the provided `.parquet` files. The commands are provided at the end of the "Generate and format training data" section, but commented out.

In [1]:
import tqdm

from qcportal import PortalClient
from openff.units import unit

from openff.toolkit import Molecule
from openff.qcsubmit.results import BasicResultCollection
from openff.recharge.esp.storage import MoleculeESPRecord
from openff.recharge.esp.qcresults import from_qcportal_results
from openff.recharge.grids import MSKGridSettings
from openff.recharge.utilities.geometry import compute_vector_field

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

import torch

## Generate and format training data


### Downloading from QCArchive
First, we will create training data. We'll download a smaller training set for the purposes of this example.

In [2]:
# Where the data lives and which part of it we want
qcarchive_address = "https://api.qcarchive.molssi.org:443"
dataset_name = "OpenFF multi-Br ESP Fragment Conformers v1.1"
specification_name = "HF/6-31G*"

# cache_dir="." points the client's cache at the current working directory
qc_client = PortalClient(qcarchive_address, cache_dir=".")

# fetch the results stored under that dataset/specification on the server
br_esps_collection = BasicResultCollection.from_server(
    client=qc_client,
    datasets=dataset_name,
    spec_name=specification_name,
)

# materialise the collection as (QC record, molecule) pairs
records_and_molecules = br_esps_collection.to_records()





### Converting to MoleculeESPRecords

Now we convert to OpenFF Recharge records and compute the ESPs and electric fields of each molecule.

In [3]:
# Build one OpenFF Recharge MoleculeESPRecord per QC record,
# using the default MSK grid settings.
grid_settings = MSKGridSettings()

# Each conversion takes several seconds, so only the first 50 records are
# used here. Shrink the slice (e.g. records_and_molecules[:10]) for a
# quicker run.
records_subset = records_and_molecules[:50]

molecule_esp_records = [
    from_qcportal_results(
        qc_result=record,
        qc_molecule=record.molecule,
        qc_keyword_set=record.specification.keywords,
        grid_settings=grid_settings,
        # also compute the electric field, which is what we train to below
        compute_field=True,
    )
    for record, _ in tqdm.tqdm(records_subset)
]

  0%|          | 0/50 [00:00<?, ?it/s]

  2%|▏         | 1/50 [00:10<08:11, 10.03s/it]

  4%|▍         | 2/50 [00:18<07:17,  9.11s/it]

  6%|▌         | 3/50 [00:27<07:14,  9.24s/it]

  8%|▊         | 4/50 [00:35<06:32,  8.54s/it]

 10%|█         | 5/50 [00:42<05:58,  7.96s/it]

 12%|█▏        | 6/50 [00:52<06:21,  8.67s/it]

 14%|█▍        | 7/50 [01:00<06:09,  8.60s/it]

 16%|█▌        | 8/50 [01:08<05:43,  8.17s/it]

 18%|█▊        | 9/50 [01:18<05:58,  8.74s/it]

 20%|██        | 10/50 [01:25<05:28,  8.21s/it]

 22%|██▏       | 11/50 [01:33<05:18,  8.17s/it]

 24%|██▍       | 12/50 [01:39<04:48,  7.60s/it]

 26%|██▌       | 13/50 [01:45<04:22,  7.09s/it]

 28%|██▊       | 14/50 [01:50<03:56,  6.58s/it]

 30%|███       | 15/50 [02:00<04:19,  7.42s/it]

 32%|███▏      | 16/50 [02:09<04:28,  7.91s/it]

 34%|███▍      | 17/50 [02:14<03:53,  7.06s/it]

 36%|███▌      | 18/50 [02:22<03:57,  7.42s/it]

 38%|███▊      | 19/50 [02:30<03:52,  7.50s/it]

 40%|████      | 20/50 [02:41<04:17,  8.57s/it]

 42%|████▏     | 21/50 [02:49<04:04,  8.43s/it]

 44%|████▍     | 22/50 [02:57<03:51,  8.26s/it]

 46%|████▌     | 23/50 [03:05<03:39,  8.15s/it]

 48%|████▊     | 24/50 [03:13<03:30,  8.08s/it]

 50%|█████     | 25/50 [03:20<03:19,  7.97s/it]

 52%|█████▏    | 26/50 [03:25<02:47,  6.98s/it]

 54%|█████▍    | 27/50 [03:31<02:34,  6.73s/it]

 56%|█████▌    | 28/50 [03:37<02:24,  6.56s/it]

 58%|█████▊    | 29/50 [03:44<02:21,  6.72s/it]

 60%|██████    | 30/50 [03:52<02:18,  6.91s/it]

 62%|██████▏   | 31/50 [03:56<01:59,  6.27s/it]

 64%|██████▍   | 32/50 [04:02<01:48,  6.04s/it]

 66%|██████▌   | 33/50 [04:08<01:40,  5.91s/it]

 68%|██████▊   | 34/50 [04:13<01:32,  5.76s/it]

 70%|███████   | 35/50 [04:19<01:28,  5.92s/it]

 72%|███████▏  | 36/50 [04:25<01:22,  5.86s/it]

 74%|███████▍  | 37/50 [04:32<01:19,  6.08s/it]

 76%|███████▌  | 38/50 [04:38<01:15,  6.29s/it]

 78%|███████▊  | 39/50 [04:45<01:11,  6.47s/it]

 80%|████████  | 40/50 [04:52<01:05,  6.57s/it]

 82%|████████▏ | 41/50 [04:59<00:58,  6.55s/it]

 84%|████████▍ | 42/50 [05:06<00:53,  6.70s/it]

 86%|████████▌ | 43/50 [05:13<00:47,  6.77s/it]

 88%|████████▊ | 44/50 [05:19<00:39,  6.59s/it]

 90%|█████████ | 45/50 [05:25<00:32,  6.54s/it]

 92%|█████████▏| 46/50 [05:32<00:26,  6.59s/it]

 94%|█████████▍| 47/50 [05:38<00:19,  6.60s/it]

 96%|█████████▌| 48/50 [05:44<00:12,  6.34s/it]

 98%|█████████▊| 49/50 [05:51<00:06,  6.59s/it]

100%|██████████| 50/50 [05:58<00:00,  6.70s/it]

100%|██████████| 50/50 [05:58<00:00,  7.18s/it]




### Convert to PyArrow dataset

NAGL reads in and trains to data from [PyArrow tables](https://arrow.apache.org/docs/python/getstarted.html#creating-arrays-and-tables). Below we convert each electric field into the form expected by a basic `GeneralLinearFitTarget`, which fits the equation Ax = b. Here A is the vector field matrix computed from the atom and grid coordinates, x is the vector of charges predicted by the GNN, and b is the electric field. To avoid carrying too many data points around, we do some postprocessing: we first flatten A from 3 dimensions to 2 (and the electric field from 2 dimensions to 1), then multiply both sides of the equation by the transpose of A.

$$\mathbf{A}\vec{x} = \vec{b}$$
$$\mathbf{A^{T}}\mathbf{A}\vec{x} = \mathbf{A^{T}}\vec{b}$$

In [4]:
pyarrow_entries = []
for esp_record in tqdm.tqdm(molecule_esp_records):
    # attach angstrom units to the stored coordinates so that they can be
    # converted to bohr for the vector field calculation
    atom_coordinates = esp_record.conformer * unit.angstrom
    grid_coordinates = esp_record.grid_coordinates * unit.angstrom
    n_atoms = atom_coordinates.shape[0]

    # this is the "A" of Ax = b
    field_tensor = compute_vector_field(
        atom_coordinates.m_as(unit.bohr),  # shape: M x 3
        grid_coordinates.m_as(unit.bohr),  # shape: N x 3
    )  # shape: N x 3 x M

    # collapse the grid-point and xyz axes into one:
    #   A: N x 3 x M -> 3N x M
    #   b: N x 3     -> 3N      (electric field, in atomic units)
    design_matrix = np.reshape(field_tensor, (-1, field_tensor.shape[-1]))
    field_values = np.ravel(esp_record.electric_field)

    # multiply both sides by A^T so that only an M x M matrix and a
    # length-M vector need storing, instead of millions of floats
    design_matrix_t = design_matrix.T
    normal_matrix = design_matrix_t @ design_matrix
    normal_vector = design_matrix_t @ field_values

    # sanity-check the reduced shapes
    assert normal_matrix.shape == (n_atoms, n_atoms)
    assert normal_vector.shape == (n_atoms,)

    # mapped_smiles is essential for every target; the other two columns
    # are the ones the GeneralLinearFitTarget below is pointed at
    pyarrow_entries.append(
        {
            "mapped_smiles": esp_record.tagged_smiles,
            "precursor_matrix": normal_matrix.ravel().tolist(),
            "prediction_vector": normal_vector.tolist(),
        }
    )


# arbitrary split: the last 10 entries are held out for validation
n_validation = 10
training_pyarrow_entries = pyarrow_entries[:-n_validation]
validation_pyarrow_entries = pyarrow_entries[-n_validation:]

training_table = pa.Table.from_pylist(training_pyarrow_entries)
validation_table = pa.Table.from_pylist(validation_pyarrow_entries)
training_table

  0%|          | 0/50 [00:00<?, ?it/s]

 74%|███████▍  | 37/50 [00:00<00:00, 369.32it/s]

100%|██████████| 50/50 [00:00<00:00, 355.41it/s]




pyarrow.Table
mapped_smiles: string
precursor_matrix: list<item: double>
  child 0, item: double
prediction_vector: list<item: double>
  child 0, item: double
----
mapped_smiles: [["[H:1][N:2]1[C:3]([Br:4])=[N:5][c:6]2[n:7][c:8]([Br:9])[c:10]([Br:11])[n:12][c:13]21","[H:1][N:2]([H:3])[C:4]1([C:5](=[O:6])[O-:7])[C:8]([H:9])([H:10])[C:11]2([H:12])[C:13]([Br:14])([Br:15])[C:16]2([H:17])[C:18]1([H:19])[H:20]","[H:1][C:2]1([H:3])[C:4]2([H:5])[C:6]([Br:7])([Br:8])[C:9]2([H:10])[C:11]([H:12])([H:13])[C:14]1([C:15](=[O:16])[O-:17])[N+:18]([H:19])([H:20])[H:21]","[H:1][N:2]1[C:3]([H:4])([H:5])[C:6]2([H:7])[C:8]([Br:9])([Br:10])[C:11]2([H:12])[C:13]1([H:14])[C:15](=[O:16])[O-:17]","[H:1][C:2]12[C:3]([Br:4])([Br:5])[C:6]1([H:7])[C:8]([H:9])([C:10](=[O:11])[O-:12])[N+:13]([H:14])([H:15])[C:16]2([H:17])[H:18]",...,"[H:1][c:2]1[c:3]([Br:4])[c:5]([H:6])[c:7]2[c:8]([c:9]1[Br:10])[C:11]([H:12])([H:13])[C:14]([H:15])([H:16])[C:17]2([H:18])[H:19]","[H:1][c:2]1[c:3]([H:4])[c:5]2[c:6]([c:7]([Br:8])[c:9]1[B

In [5]:
# Save both tables as parquet files. The DatasetConfig sources defined
# later in this notebook point at these exact filenames.
pq.write_table(training_table, "training_dataset_table.parquet")
pq.write_table(validation_table, "validation_dataset_table.parquet")

# # To skip the data generation above, load the provided files instead.
# # Note: the provided files give the full dataset, not the 50 record subset,
# # and running the two writes above would overwrite files of the same name
# # in the working directory.
# training_table = pq.read_table("training_dataset_table.parquet")
# validation_table = pq.read_table("validation_dataset_table.parquet")

## Set up for training a GNN

In [6]:
from openff.nagl.config import (
    TrainingConfig,
    OptimizerConfig,
    ModelConfig,
    DataConfig
)
from openff.nagl.config.model import (
    ConvolutionModule, ReadoutModule,
    ConvolutionLayer, ForwardLayer,
)
from openff.nagl.config.data import DatasetConfig
from openff.nagl.training.training import TrainingGNNModel
from openff.nagl.features.atoms import (
    AtomicElement,
    AtomConnectivity,
    AtomInRingOfSize,
    AtomAverageFormalCharge,
)

from openff.nagl.training.loss import GeneralLinearFitTarget

### Defining the training config

#### Defining a ModelConfig

First we define a ModelConfig.
This can be done in Python, but in practice it is probably easier to define the model in a YAML file and load it with `ModelConfig.from_yaml`.

In [7]:
# per-atom input features: element, connectivity, membership of 3- to
# 6-membered rings, and the average formal charge
atom_features = [
    AtomicElement(categories=["H", "C", "N", "O", "F", "Br", "S", "P", "I"]),
    AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),
    *(AtomInRingOfSize(ring_size=size) for size in (3, 4, 5, 6)),
    AtomAverageFormalCharge(),
]

# convolution module: 6 identical SAGEConv layers with dropout 0 (default),
# hidden feature size 512, and ReLU activation function.
# Layers can also be specified individually; here one layer definition
# is simply repeated.
convolution_layer = ConvolutionLayer(
    hidden_feature_size=512,
    activation_function="ReLU",
    aggregator_type="mean",
)
convolution_module = ConvolutionModule(
    architecture="SAGEConv",
    layers=[convolution_layer] * 6,
)

# readout module/s: several are allowed, but we only need charges.
# The dictionary key names the output property (any name is allowed);
# the readout uses 2 identical forward layers.
readout_layer = ForwardLayer(
    hidden_feature_size=512,
    activation_function="ReLU",
)
charge_readout = ReadoutModule(
    pooling="atoms",
    postprocess="compute_partial_charges",
    layers=[readout_layer] * 2,
)
readout_modules = {"charges": charge_readout}

# assemble the full model definition
model_config = ModelConfig(
    version="0.1",
    atom_features=atom_features,
    convolution=convolution_module,
    readouts=readout_modules,
)

#### Defining a DataConfig

We can then define our dataset configs. Here we also have to specify our training targets.

In [8]:
# The loss compares (precursor_matrix @ predicted charges) against
# prediction_vector, i.e. both sides of the A^T A x = A^T b equation
# prepared earlier.
target = GeneralLinearFitTarget(
    # table column holding the reference values the loss is evaluated against
    target_label="prediction_vector",
    # name of the GNN readout whose output enters the loss
    prediction_label="charges",
    # table column holding the (flattened) precursor matrix
    design_matrix_column="precursor_matrix",
    # loss metric, e.g. RMSE, MSE, ...
    metric="rmse",
    # relative weight of this target;
    # helps with scaling in multi-target optimizations
    weight=1,
    denominator=1,
)

# the same batch size is used for training and validation
batch_size = 100

training_to_electric_field = DatasetConfig(
    sources=["training_dataset_table.parquet"],
    targets=[target],
    batch_size=batch_size,
)
validating_to_electric_field = DatasetConfig(
    sources=["validation_dataset_table.parquet"],
    targets=[target],
    batch_size=batch_size,
)

# combine the two dataset definitions
data_config = DataConfig(
    training=training_to_electric_field,
    validation=validating_to_electric_field,
)

#### Defining an OptimizerConfig

The optimizer config is relatively simple; the only moving part here currently is the learning rate.

In [9]:
# Use the Adam optimizer with a learning rate of 0.001
optimizer_config = OptimizerConfig(optimizer="Adam", learning_rate=0.001)

#### Creating a TrainingConfig

In [10]:
# Bundle the model, data and optimizer configs into a single training config
training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)

### Creating a TrainingGNNModel

Now we can create a `TrainingGNNModel`, which allows easy training of a `GNNModel`. The `GNNModel` can be accessed through `TrainingGNNModel.model`.

In [11]:
# Build the trainable model from the training config. Leaving the variable
# as the last expression of the cell displays its architecture.
training_model = TrainingGNNModel(training_config)
training_model

TrainingGNNModel(
  (model): GNNModel(
    (convolution_module): ConvolutionModule(
      (gcn_layers): SAGEConvStack(
        (0): SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=20, out_features=512, bias=False)
          (fc_self): Linear(in_features=20, out_features=512, bias=True)
        )
        (1-5): 5 x SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=512, out_features=512, bias=False)
          (fc_self): Linear(in_features=512, out_features=512, bias=True)
        )
      )
    )
    (readout_modules): ModuleDict(
      (charges): ReadoutModule(
        (pooling_layer): PoolAtomFeatures()
        (readout_layers): SequentialLayers(
          (0): Linear(in_features=512, out_features=512, bias=True)
          (1): ReLU()
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, 

We can look at the initial capabilities of the model by comparing its charges to AM1-BCC charges. They're pretty bad.

In [12]:
# reference: AM1-BCC charges for 1-bromopropane, as a plain array (units stripped)
test_molecule = Molecule.from_smiles("CCCBr")
test_molecule.assign_partial_charges("am1bcc")
reference_charges = test_molecule.partial_charges.m

# predict with the untrained GNN: eval mode, no gradient tracking
gnn_model = training_model.model
gnn_model.eval()
with torch.no_grad():
    predicted_properties = gnn_model.compute_properties(
        test_molecule,
        as_numpy=True,
    )
nagl_charges_1 = predicted_properties["charges"]

# restore training mode so the model is ready to be fitted
gnn_model.train()

# per-atom difference between reference and predicted charges
differences = reference_charges - nagl_charges_1
differences

array([-0.57983828, -1.28828087, -0.28952769,  0.10409178,  0.25788712,
        0.25788712,  0.25788712,  0.2761532 ,  0.2761532 ,  0.36379363,
        0.36379363])

### Training the model

We use PyTorch Lightning to train.

In [13]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

In [14]:
# Configure the PyTorch Lightning trainer
trainer = pl.Trainer(
    max_epochs=100,  # stop fitting once 100 epochs have been run
    callbacks=[TQDMProgressBar()],  # show a tqdm progress bar while fitting
    accelerator="cpu"  # run on the CPU
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.


GPU available: False, used: False


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [15]:
# Create the data module that is handed to trainer.fit below
datamodule = training_model.create_data_module(verbose=False)

In [16]:
# Run the training loop; it stops once the trainer's max_epochs is reached
trainer.fit(
    training_model,
    datamodule=datamodule,
)

Featurizing dataset: 0it [00:00, ?it/s]




Featurizing batch:   0%|          | 0/40 [00:00<?, ?it/s]

[A




Featurizing batch:  22%|██▎       | 9/40 [00:00<00:00, 84.88it/s]

[A




Featurizing batch:  45%|████▌     | 18/40 [00:00<00:00, 86.48it/s]

[A




Featurizing batch:  70%|███████   | 28/40 [00:00<00:00, 86.36it/s]

[A




Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 96.14it/s]

[A

Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 92.20it/s]


Featurizing dataset: 1it [00:00,  2.26it/s]

Featurizing dataset: 1it [00:00,  2.26it/s]




Featurizing dataset: 0it [00:00, ?it/s]




Featurizing batch:   0%|          | 0/10 [00:00<?, ?it/s]

[A

Featurizing batch: 100%|██████████| 10/10 [00:00<00:00, 109.36it/s]


Featurizing dataset: 1it [00:00, 10.52it/s]




2025-09-30 00:24:24.787390: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Featurizing dataset: 0it [00:00, ?it/s]




Featurizing batch:   0%|          | 0/40 [00:00<?, ?it/s]

[A




Featurizing batch:  22%|██▎       | 9/40 [00:00<00:00, 88.56it/s]

[A




Featurizing batch:  45%|████▌     | 18/40 [00:00<00:00, 89.01it/s]

[A




Featurizing batch:  68%|██████▊   | 27/40 [00:00<00:00, 89.33it/s]

[A




Featurizing batch:  95%|█████████▌| 38/40 [00:00<00:00, 94.46it/s]

[A

Featurizing batch: 100%|██████████| 40/40 [00:00<00:00, 92.78it/s]


Featurizing dataset: 1it [00:00,  2.29it/s]

Featurizing dataset: 1it [00:00,  2.28it/s]




Featurizing dataset: 0it [00:00, ?it/s]




Featurizing batch:   0%|          | 0/10 [00:00<?, ?it/s]

[A

Featurizing batch: 100%|██████████| 10/10 [00:00<00:00, 108.58it/s]


Featurizing dataset: 1it [00:00, 10.43it/s]



  | Name  | Type     | Params | Mode 
-------------------------------------------
0 | model | GNNModel | 3.2 M  | train
-------------------------------------------
3.2 M     Trainable params
0         Non-trainable params
3.2 M     Total params
12.685    Total estimated model params size (MB)
47        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/runner/micromamba/envs/openff-docs-examples/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


We can now check the charges again. They should have improved, especially if you used a larger dataset (note: results may vary with the small 50-molecule dataset this notebook has chosen to use, for reasons of speed).

In [17]:
# predict again with the trained GNN: eval mode, no gradient tracking
trained_gnn = training_model.model
trained_gnn.eval()
with torch.no_grad():
    trained_properties = trained_gnn.compute_properties(
        test_molecule,
        as_numpy=True,
    )
nagl_charges_2 = trained_properties["charges"]

# per-atom difference between the AM1-BCC reference and the new prediction
differences_after_training = reference_charges - nagl_charges_2
differences_after_training

array([-0.02446773, -0.01090277,  0.01771042, -0.04515357, -0.027538  ,
       -0.027538  , -0.027538  ,  0.00291761,  0.00291761,  0.06979619,
        0.06979619])