# Train a GNN directly to predict torsion energies

**Note: this is an *experimental* notebook, intended only to demonstrate a proof of concept.
While some parts of this notebook may eventually be fully supported by OpenFF NAGL, the general conclusion arrived at is that it is likely easiest to work on this solution outside of the NAGL framework.**

To execute this example fully, the following packages are required.

* openff-nagl
* openff-recharge
* openff-qcsubmit
* psi4

However, if you wish to just follow along the training part without first creating the training datasets yourself, you can get away with just `openff-nagl` installed and simply load the training/validation data from the provided `.parquet` files. The commands are provided at the end of the "Generate and format training data" section, but commented out.

In [1]:
import collections
import tqdm

from qcportal import PortalClient
from openff.units import unit

from openff.toolkit import Molecule, ForceField
from openff.qcsubmit.results import BasicResultCollection
from openff.recharge.esp.storage import MoleculeESPRecord
from openff.recharge.esp.qcresults import from_qcportal_results
from openff.recharge.grids import MSKGridSettings
from openff.recharge.utilities.geometry import compute_vector_field

import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np

import torch
import MDAnalysis as mda

## Generate and format training data


### Downloading from QCArchive
First, we will create training data. We'll download a smaller training set for the purposes of this example.

In [2]:
# qc_client = PortalClient("https://api.qcarchive.molssi.org:443", cache_dir=".")

# # download dataset from QCPortal
# br_esps_collection = BasicResultCollection.from_server(
#     client=qc_client,
#     datasets="OpenFF multi-Br ESP Fragment Conformers v1.1",
#     spec_name="HF/6-31G*",
# )

# records_and_molecules = br_esps_collection.to_records()

### Convert to PyArrow dataset

NAGL reads and trains on data from [PyArrow tables](https://arrow.apache.org/docs/python/getstarted.html#creating-arrays-and-tables). Below we create some simple data by using an existing force field to assign an energy to each individual proper torsion. These energies are then summed over all torsions that share the same central bond, giving one value per bond.

*Note: the maths could probably use double-checking.*
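
One way to check the maths by hand is to evaluate the periodic torsion form used in the next cell, `k * (1 + cos(n * angle - phase))` summed over terms, with plain floats instead of OpenFF parameter objects and units. The sketch below does that. The function name `periodic_torsion_energy` and the parameter values are made up for illustration and are not taken from any force field.

```python
import math


def periodic_torsion_energy(angle_deg, ks, phases_deg, periodicities):
    """Sum k * (1 + cos(n * angle - phase)) over all terms of one torsion."""
    angle = math.radians(angle_deg)
    total = 0.0
    for k, phase_deg, n in zip(ks, phases_deg, periodicities):
        total += k * (1.0 + math.cos(n * angle - math.radians(phase_deg)))
    return total


# hypothetical single-term torsion: k = 1.0, phase = 0 degrees, periodicity = 3
eclipsed = periodic_torsion_energy(0.0, [1.0], [0.0], [3])
staggered = periodic_torsion_energy(60.0, [1.0], [0.0], [3])
print(eclipsed, staggered)
```

With these values the term reaches its maximum of `2 * k` at 0 degrees and vanishes (up to floating-point error) at 60 degrees, which is the expected threefold pattern.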

In [3]:
def calc_torsion_energy(angle, parameter):
    angle = (angle * unit.degrees).m_as(unit.radians)
    total = 0 * unit.kilojoules_per_mole
    for k, phase, periodicity in zip(parameter.k, parameter.phase, parameter.periodicity):
        phase = phase.m_as(unit.radians)
        subtotal = k * (1 + np.cos(periodicity * angle - phase))
        total += subtotal
    return total


def get_central_bond_torsions(mol, forcefield):
    labels = forcefield.label_molecules(mol.to_topology())[0]["ProperTorsions"]
    u = mda.Universe(mol.to_rdkit())

    bonds_to_dihedrals = collections.defaultdict(list)
    for key in labels:
        center = tuple(sorted(key[1:3]))
        bonds_to_dihedrals[center].append(key)

    # sort bonds...
    energies = []
    keys = sorted([
        tuple(sorted([bond.atom1_index, bond.atom2_index]))
        for bond in mol.bonds
    ])
    for key in keys:
        dihedrals = bonds_to_dihedrals[key]
        
        energy = 0 * unit.kilojoules_per_mole
        for dihedral_indices in dihedrals:
            dihedral_parameter = labels[dihedral_indices]
            angle = u.atoms[list(dihedral_indices)].dihedral.value()
            energy += calc_torsion_energy(angle, dihedral_parameter)

        energies.append(energy.m_as(unit.kilojoules_per_mole))
    return energies
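
The grouping step in `get_central_bond_torsions` can be illustrated without any chemistry toolkits. The sketch below is an assumption-labelled stand-in: `sum_by_central_bond`, the torsion indices and the energies are all hypothetical. It shows how per-torsion energies collapse to one value per bond, with bonds that are not central to any torsion receiving 0.

```python
from collections import defaultdict


def sum_by_central_bond(torsion_energies, bonds):
    """Sum per-torsion energies over the sorted central bond (atoms 2 and 3)."""
    totals = defaultdict(float)
    for (_, j, k, _), energy in torsion_energies.items():
        totals[tuple(sorted((j, k)))] += energy
    # one entry per bond, in sorted bond order; missing bonds default to 0.0
    sorted_bonds = sorted(tuple(sorted(bond)) for bond in bonds)
    return [totals[bond] for bond in sorted_bonds]


# hypothetical molecule: two torsions share the central bond (1, 2)
torsion_energies = {(0, 1, 2, 3): 1.5, (4, 1, 2, 3): 0.5}
bonds = [(0, 1), (2, 1), (2, 3), (1, 4)]
per_bond = sum_by_central_bond(torsion_energies, bonds)
print(per_bond)
```

Sorting both the central-bond key and the bond list is what keeps the output aligned with a fixed bond ordering, which the later pooling layer relies on.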

In [4]:
pyarrow_entries = []
forcefield = ForceField("openff-2.1.0.offxml")
# for record, molecule in tqdm.tqdm(records_and_molecules):
#     central_torsion_energies = get_central_bond_torsions(molecule, forcefield)
#     entry = {
#         # hopefully this preserves bond order
#         "mapped_smiles": molecule.to_smiles(mapped=True),
#         "torsion_energies": central_torsion_energies,
#         "conformer": molecule.conformers[0].m_as(unit.angstrom).flatten().tolist(),
#     }
#     pyarrow_entries.append(entry)


# # arbitrarily split into training and validation datasets
# training_pyarrow_entries = pyarrow_entries[:-10]
# validation_pyarrow_entries = pyarrow_entries[-10:]

# training_table = pa.Table.from_pylist(training_pyarrow_entries)
# validation_table = pa.Table.from_pylist(validation_pyarrow_entries)
# training_table
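
Each row of the table is a plain dictionary, with the conformer stored as a flat list of coordinates. The sketch below shows that layout with made-up values. The molecule, coordinates and the helper `unflatten` are hypothetical, and the flattening is done in pure Python to mirror what `ndarray.flatten().tolist()` produces for a row-major `(n_atoms, 3)` array.

```python
# hypothetical 2-atom conformer as an (n_atoms, 3) nested list, in angstrom
coordinates = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]

# flatten row by row: x1, y1, z1, x2, y2, z2, ...
flat = [value for xyz in coordinates for value in xyz]

entry = {
    "mapped_smiles": "[H:1][H:2]",
    "torsion_energies": [0.0],  # one value per bond; H2 has no torsions
    "conformer": flat,
}


def unflatten(flat_coordinates):
    """Recover the (n_atoms, 3) layout from the flat list."""
    return [flat_coordinates[i:i + 3] for i in range(0, len(flat_coordinates), 3)]


print(unflatten(entry["conformer"]))
```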

In [5]:
# pq.write_table(training_table, "training_dataset_table.parquet")
# pq.write_table(validation_table, "validation_dataset_table.parquet")

# to read back in -- note, the files saved here give the full dataset, not the 50 record subset
training_table = pq.read_table("training_dataset_table.parquet")
validation_table = pq.read_table("validation_dataset_table.parquet")

## Set up for training a GNN

In [6]:
from openff.nagl.config import (
    TrainingConfig,
    OptimizerConfig,
    ModelConfig,
    DataConfig
)
from openff.nagl.config.model import (
    ConvolutionModule, ReadoutModule,
    ConvolutionLayer, ForwardLayer,
)
from openff.nagl.config.data import DatasetConfig
from openff.nagl.training.training import TrainingGNNModel
from openff.nagl.features.atoms import (
    AtomicElement,
    AtomConnectivity,
    AtomInRingOfSize,
    AtomAverageFormalCharge,
)

from openff.nagl.training.loss import GeneralLinearFitTarget

### Defining the training config

#### Defining a ModelConfig

First we define a `ModelConfig`. This is done in Python so that we can define a custom `PostprocessLayer` that works with the c_ij coefficients to produce alpha, and a custom bond feature pooling layer that takes alpha as an input.

Caveats:
- Both of these custom layers use *new features* implemented in this branch, which are currently unsupported by OpenFF NAGL proper.
- Everything assumes that bonds, angles, torsions, etc. are properly sorted (which is how it's implemented in the branch). Anything else would require more accounting.
- The implementation in NAGL currently doesn't allow for multiple molecules.

*Also note, again the maths below could probably use double-checking.*

In [7]:
from openff.nagl.nn.postprocess import PostprocessLayer
from openff.nagl.nn._pooling import PoolProperTorsionFeatures, PoolBondFeatures
from collections import defaultdict

class ComputeSCoefficients(PostprocessLayer):
    """Computes c_ij"""

    name: str = "compute_s_coefficients"
    n_features: int = 1

    def __init__(self, pooling_layer = None):
        super().__init__()
        self._pooling_layer = pooling_layer

    def forward(
        self,
        molecule,
        inputs: torch.Tensor,
        **kwargs
    ):
        c_ij = inputs[:, 0]  # first readout column, shape (n_torsions,)
        c_ij = torch.flatten(c_ij)  # (n_torsions,)

        d_ij = self._pooling_layer._calculate_internal_coordinates(molecule)
        s_ij = torch.empty((2, *d_ij.shape), dtype=d_ij.dtype)
        s_ij[0, :] = torch.cos(d_ij)
        s_ij[1, :] = torch.sin(d_ij)
        s_ij = s_ij.T # (n_torsions, 2)

        # filter by central bond
        proper_torsion_indices_T = molecule._pooling_representations["proper_torsion"]

        bond_indices = defaultdict(list)
        for i, atom_2 in enumerate(
            proper_torsion_indices_T[1]
        ):
            atom_3 = proper_torsion_indices_T[2][i]
            bond = tuple(sorted([atom_2.item(), atom_3.item()]))
            bond_indices[bond].append(i)

        a_dict = {}
        
        for bond, indices in bond_indices.items():
            s = torch.sum(
                c_ij[indices].reshape((-1, 1)) * s_ij[indices],
                dim=0
            )
            s_norm = torch.norm(s)
            a = torch.arctan2(*(s / s_norm))
            a_dict[bond] = a

        # ... set bonds that aren't central bonds in torsions to 0?
        all_bond_indices = molecule._get_bonds()
        for key in all_bond_indices:
            if key not in a_dict:
                a_dict[key] = torch.tensor(0)

        # sort
        sorted_keys = sorted(a_dict)
        alphas = torch.empty((len(a_dict),)).flatten()
        for i, key in enumerate(sorted_keys):
            alphas[i] = a_dict[key].item()

        # note if molecules are batched, bond indices will be cumulative
        return alphas


class InjectablePoolBondFeatures(PoolBondFeatures):
    name = "injectable_bond"
    def _get_final_representations(self, molecule, readouts=None, **kwargs):
        representations = self._get_pooled_representations(molecule)

        # assume bonds are properly sorted
        alphas = readouts["alpha"].reshape((-1, 1))

        representations = [
            torch.cat([representation, alphas], dim=1)
            for representation in representations
        ]
        return representations
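
The arithmetic in `ComputeSCoefficients.forward` can be traced by hand for a single bond. The sketch below uses plain `math` rather than `torch`, with made-up coefficients and dihedrals. One point worth checking: `torch.arctan2(input, other)` treats its first argument as the y component, so unpacking a `(cos, sin)` pair as written corresponds to `atan2(cos_sum, sin_sum)`. The conventional angle of the summed vector would be `atan2(sin_sum, cos_sum)`. Which of the two is intended is not stated here, so both are shown.

```python
import math


def weighted_vector_sum(coefficients, dihedrals):
    """Sum c * (cos(d), sin(d)) over all torsions around one central bond."""
    cos_sum = sum(c * math.cos(d) for c, d in zip(coefficients, dihedrals))
    sin_sum = sum(c * math.sin(d) for c, d in zip(coefficients, dihedrals))
    return cos_sum, sin_sum


# hypothetical bond with a single torsion at 30 degrees and coefficient 1.0
cos_sum, sin_sum = weighted_vector_sum([1.0], [math.radians(30.0)])

# argument order as in the cell above: first argument (y) is the cosine sum
alpha_as_written = math.atan2(cos_sum, sin_sum)
# conventional angle of the vector (cos_sum, sin_sum)
alpha_conventional = math.atan2(sin_sum, cos_sum)

print(math.degrees(alpha_as_written), math.degrees(alpha_conventional))
```

Dividing by the norm before calling `atan2` does not change either result, since the function only depends on the ratio and signs of its arguments. For this single torsion the two conventions give 60 and 30 degrees respectively, that is, they differ by a reflection about 45 degrees.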

Now for the normal definition of a model. Note that this uses the same features and general architecture as the NAGL model used for AM1-BCC partial charges.

The readout modules calculate *two* properties: 1) alpha and 2) energies.

The GNNModel version is set to '0.2' so that it is incompatible with what is currently supported in OpenFF NAGL.

In [8]:
atom_features = [
    AtomicElement(categories=["H", "C", "N", "O", "F", "Br", "S", "P", "I"]),
    AtomConnectivity(categories=[1, 2, 3, 4, 5, 6]),
    AtomInRingOfSize(ring_size=3),
    AtomInRingOfSize(ring_size=4),
    AtomInRingOfSize(ring_size=5),
    AtomInRingOfSize(ring_size=6),
    AtomAverageFormalCharge(),
]

# define our convolution module
convolution_module = ConvolutionModule(
    architecture="SAGEConv",
    # construct 6 layers with dropout 0 (default),
    # hidden feature size 512, and ReLU activation function
    # these layers can also be individually specified,
    # but we just duplicate the layer 6 times for identical layers
    layers=[
        ConvolutionLayer(
            hidden_feature_size=512,
            activation_function="ReLU",
            aggregator_type="mean"
        )
    ] * 6,
)

# define our readout modules
# multiple are allowed; here we define two, "alpha" and "energies"
readout_modules = {
    # key is the name of output property, any naming is allowed
    "alpha": ReadoutModule(
        pooling="proper_torsion",
        postprocess=ComputeSCoefficients(),
        # 2 layers
        layers=[
            ForwardLayer(
                hidden_feature_size=512,
                activation_function="ReLU",
            )
        ] * 2,
    ),
    "energies": ReadoutModule(
        pooling=InjectablePoolBondFeatures,
        layers=[
            ForwardLayer(
                hidden_feature_size=512,
                activation_function="ReLU",
            ),
            ForwardLayer(
                hidden_feature_size=512,
                activation_function="ReLU",
            ),
            ForwardLayer(
                hidden_feature_size=1,
                activation_function="Identity",
            )
        ]
    )
}

# bring it all together
model_config = ModelConfig(
    version="0.2",
    atom_features=atom_features,
    convolution=convolution_module,
    readouts=readout_modules,
    include_xyz=True,
)
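
As a quick sanity check on `atom_features`, the width of the per-atom input vector can be tallied by hand. The sketch below assumes that categorical features are one-hot encoded (one column per category) and that each ring-membership and formal-charge feature contributes a single column. Those encodings are assumptions of this sketch, but the total agrees with `in_features=20` reported for the first `SAGEConv` layer when the model is printed.

```python
# assumed number of columns contributed by each atom feature
feature_widths = {
    "AtomicElement": 9,         # H, C, N, O, F, Br, S, P, I
    "AtomConnectivity": 6,      # categories 1 to 6
    "AtomInRingOfSize(3)": 1,
    "AtomInRingOfSize(4)": 1,
    "AtomInRingOfSize(5)": 1,
    "AtomInRingOfSize(6)": 1,
    "AtomAverageFormalCharge": 1,
}

n_atom_features = sum(feature_widths.values())
print(n_atom_features)
```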

#### Defining a DataConfig

We can then define our dataset configs. Here we also have to specify our training targets.

In [9]:
from openff.nagl.training.loss import ReadoutTarget

target = ReadoutTarget(
    # what we're using to evaluate loss
    target_label="torsion_energies",
    # the output of the GNN we use to evaluate loss
    prediction_label="energies",
    # how we want to evaluate loss, e.g. RMSE, MSE, ...
    metric="rmse",
    # how much to weight this target
    # helps with scaling in multi-target optimizations
    weight=1,
    denominator=1,
)

training_to_torsions = DatasetConfig(
    sources=["training_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)
validating_to_torsions = DatasetConfig(
    sources=["validation_dataset_table.parquet"],
    targets=[target],
    batch_size=100,
)

# bringing it together
data_config = DataConfig(
    training=training_to_torsions,
    validation=validating_to_torsions
)
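
To make the target settings concrete, the sketch below computes an RMSE in plain Python and then scales it. The scaling `weight * rmse / denominator` is an assumption about how `weight` and `denominator` enter the loss and should be checked against the NAGL source. The function name `weighted_rmse` and the numbers are hypothetical.

```python
import math


def weighted_rmse(predicted, reference, weight=1.0, denominator=1.0):
    """Root-mean-square error, scaled as weight * rmse / denominator (assumed)."""
    squared_errors = [(p - r) ** 2 for p, r in zip(predicted, reference)]
    rmse = math.sqrt(sum(squared_errors) / len(squared_errors))
    return weight * rmse / denominator


# made-up per-bond energies for one molecule
loss = weighted_rmse([3.0, 4.0], [0.0, 0.0])
scaled_loss = weighted_rmse([3.0, 4.0], [0.0, 0.0], weight=2.0, denominator=4.0)
print(loss, scaled_loss)
```

With `weight=1` and `denominator=1`, as in the target above, the scaled value is simply the RMSE.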

#### Defining an OptimizerConfig

The optimizer config is relatively simple; the only moving part here currently is the learning rate.

In [10]:
optimizer_config = OptimizerConfig(optimizer="Adam", learning_rate=0.001)

#### Creating a TrainingConfig

In [11]:
training_config = TrainingConfig(
    model=model_config,
    data=data_config,
    optimizer=optimizer_config
)

### Creating a TrainingGNNModel

Now we can create a `TrainingGNNModel`, which allows easy training of a `GNNModel`. The `GNNModel` can be accessed through `TrainingGNNModel.model`.

In [12]:
training_model = TrainingGNNModel(training_config)
training_model

TrainingGNNModel(
  (model): GNNModel(
    (convolution_module): ConvolutionModule(
      (gcn_layers): SAGEConvStack(
        (0): SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=20, out_features=512, bias=False)
          (fc_self): Linear(in_features=20, out_features=512, bias=True)
        )
        (1-5): 5 x SAGEConv(
          (feat_drop): Dropout(p=0.0, inplace=False)
          (activation): ReLU()
          (fc_neigh): Linear(in_features=512, out_features=512, bias=False)
          (fc_self): Linear(in_features=512, out_features=512, bias=True)
        )
      )
    )
    (readout_modules): ModuleDict(
      (alpha): ReadoutModule(
        (pooling_layer): PoolProperTorsionFeatures(
          (layers): SequentialLayers(
            (0): Linear(in_features=2049, out_features=512, bias=True)
            (1): ReLU()
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_featu

We can look at the initial capabilities of the model by comparing its energies to reference data.

In [13]:
torsion_layer = training_model.model.readout_modules["alpha"].pooling_layer
# very hacky assignment to make current model work
# NAGL doesn't elegantly allow for passing this in during model creation
training_model.model.readout_modules["alpha"].postprocess_layer._pooling_layer = torsion_layer

In [14]:
test_molecule = Molecule.from_smiles("CCCBr")
test_molecule.generate_conformers(n_conformers=1)
reference_energies = get_central_bond_torsions(test_molecule, forcefield)

# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    energies_1 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["energies"]

# switch back to training mode
training_model.model.train()

# compare energies
differences = reference_energies - energies_1
differences





array([-0.19341571, -0.22279456, -0.22279456, -0.22279456,  7.53086858,
       -0.20101197, -0.20101197, -0.20901833, -0.21364443, -0.21364446])

In [15]:
reference_energies

[0.030007618485268825, 0, 0, 0, 7.7472088522614175, 0, 0, 0, 0, 0]

In [19]:
energies_1

array([0.22342333, 0.22279456, 0.22279456, 0.22279456, 0.21634027,
       0.20101197, 0.20101197, 0.20901833, 0.21364443, 0.21364446])

### Training the model

We use PyTorch Lightning to train.

In [20]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

In [21]:
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[TQDMProgressBar()], # add progress bar
    accelerator="cpu"
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
datamodule = training_model.create_data_module(verbose=False)

Currently there is an abundance of warnings about conformer geometries being in angstrom.

In [23]:
trainer.fit(
    training_model,
    datamodule=datamodule,
)

Featurizing batch: 100%|██████████████████████| 640/640 [00:07<00:00, 83.74it/s]
Featurizing dataset: 1it [00:07,  7.67s/it]
Featurizing dataset: 0it [00:00, ?it/s]

Featurizing batch: 100%|████████████████████████| 10/10 [00:00<00:00, 57.63it/s]
Featurizing dataset: 1it [00:00,  5.62it/s]


`Trainer.fit` stopped: `max_epochs=100` reached.


We can now check the energies again. They should have improved, especially if you used a larger dataset. Results may vary with the small 50-molecule dataset that this notebook uses for reasons of speed.

Note, the energies are 

In [24]:
# switch to eval mode
training_model.model.eval()

with torch.no_grad():
    energies_2 = training_model.model.compute_properties(
        test_molecule,
        as_numpy=True
    )["energies"]

differences_after_training = reference_energies - energies_2
differences_after_training



array([-0.31952874,  0.00871341,  0.00871341,  0.00871341,  2.61271455,
       -0.0035692 , -0.0035692 ,  0.00833304,  0.02520727,  0.02520727])

In [25]:
# sum of squared differences before training
sum(differences ** 2)

57.11609162500679

In [26]:
sum(differences_after_training ** 2)

6.929969425240237
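
The two sums of squared differences are easier to compare as a per-bond RMSE. The sketch below only reuses the two totals printed above and the 10 bonds of the test molecule. The variable names are illustrative.

```python
import math

n_bonds = 10  # length of the energy arrays for the test molecule

# sums of squared differences copied from the outputs above
sum_sq_before = 57.11609162500679
sum_sq_after = 6.929969425240237

rmse_before = math.sqrt(sum_sq_before / n_bonds)
rmse_after = math.sqrt(sum_sq_after / n_bonds)

print(f"RMSE before training: {rmse_before:.2f}")
print(f"RMSE after training:  {rmse_after:.2f}")
```

This works out to roughly 2.4 before training and 0.8 after, in the same units as the reference energies. Most of the remaining error comes from the single bond with the largest reference energy.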

In [27]:
energies_2

array([ 3.49536359e-01, -8.71341489e-03, -8.71341489e-03, -8.71341489e-03,
        5.13449430e+00,  3.56920343e-03,  3.56920343e-03, -8.33304413e-03,
       -2.52072699e-02, -2.52072699e-02])

What's alpha?

In [28]:
training_model.model.compute_properties(
    test_molecule,
    as_numpy=True
)["alpha"]

array([1.1623677 , 0.        , 0.        , 0.        , 1.16141546,
       0.        , 0.        , 0.        , 0.        , 0.        ])