# Prepare a NAGL dataset for training

Training a GCN requires a collection of examples that the GCN should reproduce and interpolate between. This notebook describes how to prepare such a dataset for predicting partial charges.

## Imports

In [1]:
from pathlib import Path

from tqdm import tqdm

from openff.toolkit.topology import Molecule

from openff.nagl.label.dataset import LabelledDataset
from openff.nagl.label.labels import LabelCharges



## Choosing our molecules

The simplest way to specify the molecules in our dataset is with SMILES, though [anything you can load](https://docs.openforcefield.org/projects/toolkit/en/stable/users/molecule_cookbook.html) into an OpenFF [`Molecule`] is fair game. For instance, with the [`Molecule.from_file()`] method you could load partial charges from SDF files. But for this example, we'll have NAGL generate our charges, so we can just provide the SMILES themselves:

[`Molecule`]: https://docs.openforcefield.org/projects/toolkit/en/stable/api/generated/openff.toolkit.topology.Molecule.html
[`Molecule.from_file()`]: https://docs.openforcefield.org/projects/toolkit/en/stable/api/generated/openff.toolkit.topology.Molecule.html#openff.toolkit.topology.Molecule.from_file

In [2]:
alkanes_smiles = Path("alkanes.smi").read_text().splitlines()
alkanes_smiles

['C',
 'CC',
 'CCC',
 'CCCC',
 'CC(C)C',
 'CCCCC',
 'CC(C)CC',
 'CCCCCC',
 'CC(C)CCC',
 'CC(CC)CC']
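`splitlines()` keeps every line verbatim, which works for this file. Many `.smi` files also carry a molecule name after the SMILES, or contain blank lines. The following sketch reads such files more defensively. It assumes the common layout of one SMILES per line, optionally followed by whitespace and a name, and it treats lines starting with `#` as comments. Both conventions are assumptions about your files, not something NAGL requires.

```python
from pathlib import Path


def read_smiles_file(path):
    """Return the SMILES strings in a .smi file, one per non-empty line.

    Assumes each line holds a SMILES string, optionally followed by whitespace
    and a molecule name; blank lines and lines starting with '#' are skipped.
    """
    smiles = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        # keep only the first whitespace-separated field (the SMILES itself)
        smiles.append(line.split()[0])
    return smiles
```

With a plain file like `alkanes.smi`, `read_smiles_file("alkanes.smi")` should return the same list as the cell above.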

## Generating a LabelledDataset

A `LabelledDataset` is a wrapper around an [Apache Arrow Dataset](https://arrow.apache.org/docs/python/api/dataset.html) that makes it easy to generate data. When we train GNN models, the data is read directly as an Arrow dataset, so a `LabelledDataset` is not required; it is simply a convenient way to generate that data. Here we demonstrate those conveniences.

In [3]:
dataset = LabelledDataset.from_smiles(
    "labelled_alkanes",  # path to save to
    alkanes_smiles,
    mapped=False,  # the input SMILES carry no atom-map indices
    overwrite_existing=True,  # replace any dataset already saved at this path
)
dataset.to_pandas()

| | mapped_smiles |
|---|---|
| 0 | [H:2][C:1]([H:3])([H:4])[H:5] |
| 1 | [H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8] |
| 2 | [H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:... |
| 3 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:... |
| 4 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:... |
| 5 | [H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C... |
| 6 | [H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14... |
| 7 | [H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[... |
| 8 | [H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17... |
| 9 | [H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12... |
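The `mapped_smiles` column stores each molecule with explicit hydrogens and an atom-map index on every atom. Those indices fix the atom ordering: in OpenFF's convention, atom index `i` carries map index `i + 1`. This stdlib-only sketch checks that a mapped SMILES is fully and consecutively mapped, and it counts the atoms. The regex and function names are our own. The regex assumes every atom is written in brackets, as in the output above.

```python
import re

# a bracket atom ending in an atom-map index, e.g. "[H:2]" or "[C@@H:3]"
ATOM_MAP_PATTERN = re.compile(r"\[[^\]]*:(\d+)\]")


def atom_map_indices(mapped_smiles):
    """Return the atom-map indices in the order they appear in the string."""
    return [int(i) for i in ATOM_MAP_PATTERN.findall(mapped_smiles)]


def check_fully_mapped(mapped_smiles):
    """Return the atom count if the map indices are exactly 1..N, else raise."""
    indices = atom_map_indices(mapped_smiles)
    if not indices:
        raise ValueError(f"no atom-map indices found in {mapped_smiles!r}")
    n_atoms = len(indices)
    if sorted(indices) != list(range(1, n_atoms + 1)):
        raise ValueError(f"atom-map indices are not 1..{n_atoms}: {indices}")
    return n_atoms
```

For row 0 (methane), `check_fully_mapped` returns 5, one carbon plus four hydrogens.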


In [4]:
# path to directory containing parquet files of dataset
dataset.source

'labelled_alkanes'

In [5]:
# actual files of the dataset
dataset.dataset.files

['labelled_alkanes/part-0.parquet']
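Here the dataset is backed by a single `part-0.parquet` file. If a dataset directory ever holds many `part-N.parquet` files, a plain lexical sort would place `part-10` before `part-2`. This small pathlib sketch lists such files in numeric order. It assumes only the `part-N.parquet` naming seen above and is independent of NAGL's own API.

```python
import re
from pathlib import Path


def sorted_part_files(directory):
    """List part-N.parquet files in `directory`, ordered by N rather than lexically."""
    pattern = re.compile(r"part-(\d+)\.parquet$")
    parts = []
    for path in Path(directory).glob("part-*.parquet"):
        match = pattern.search(path.name)
        if match:
            parts.append((int(match.group(1)), path))
    # sort on the integer N; the Path is only a tie-breaker
    return [path for _, path in sorted(parts)]
```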

## Generating charges

NAGL can generate AM1-BCC and AM1-Mulliken charges automatically using the OpenFF Toolkit.
`exist_ok` controls whether it is an error for `charge_column` to already be present in the dataset: with `exist_ok=False`, an existing column raises an error.
Normally we want this to be `False`, but here it is set to `True` so that the cell can be run
multiple times.
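The following toy stand-in makes the `exist_ok` semantics concrete. It uses a plain dict of columns in place of the Arrow dataset. `add_column` is a hypothetical helper, not part of NAGL. It mirrors only the behaviour described above: with `exist_ok=False`, an existing column is an error. When `exist_ok=True`, this sketch simply overwrites the column. That is an illustrative choice and not necessarily what NAGL does.

```python
def add_column(table, name, values, exist_ok=False):
    """Add `values` to the dict-of-columns `table` under `name` (hypothetical helper).

    With exist_ok=False an existing column raises ValueError; with exist_ok=True
    this toy version overwrites it.
    """
    if name in table and not exist_ok:
        raise ValueError(f"column {name!r} already exists")
    table[name] = list(values)
    return table
```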

In [6]:
am1bcc_labeller = LabelCharges(
    charge_method="am1bcc",
    charge_column="am1bcc_charges",
    exist_ok=True,
)
am1_labeller = LabelCharges(
    charge_method="am1-mulliken",
    charge_column="am1_charges",
    exist_ok=True,
)

# label every molecule with both charge models; verbose=True shows progress bars
dataset.apply_labellers(
    [am1_labeller, am1bcc_labeller],
    verbose=True,
)
dataset.to_pandas()


| | mapped_smiles | am1_charges | am1bcc_charges |
|---|---|---|---|
| 0 | [H:2][C:1]([H:3])([H:4])[H:5] | [-0.2658799886703491, 0.06646999716758728, 0.0... | [-0.10868000239133835, 0.027170000597834587, 0... |
| 1 | [H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8] | [-0.21174000017344952, -0.21174000017344952, 0... | [-0.09384000208228827, -0.09384000208228827, 0... |
| 2 | [H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:... | [-0.21018000082536178, -0.15999999777837234, -... | [-0.09227999977090141, -0.08139999888160011, -... |
| 3 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:... | [-0.21003000438213348, -0.15905000269412994, -... | [-0.09212999844125339, -0.08044999891093799, -... |
| 4 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:... | [-0.20747000138674462, -0.10981000374470438, -... | [-0.08957000076770782, -0.07050999999046326, -... |
| 5 | [H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C... | [-0.21004000306129456, -0.15812000632286072, -... | [-0.09213999658823013, -0.07952000200748444, -... |
| 6 | [H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14... | [-0.20766000405830495, -0.10704000250381582, -... | [-0.0897599982426447, -0.06774000100353185, -0... |
| 7 | [H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[... | [-0.21021999344229697, -0.15823000594973563, -... | [-0.0923200011253357, -0.0796300008893013, -0.... |
| 8 | [H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17... | [-0.208649992197752, -0.1059999980032444, -0.2... | [-0.09075000137090683, -0.06669999659061432, -... |
| 9 | [H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12... | [-0.2068299949169159, -0.10380999743938446, -0... | [-0.08893000297248363, -0.06451000235974788, -... |
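The two charge columns are related. AM1-BCC charges are obtained by adding bond charge corrections (BCCs) to AM1 atomic charges. Each correction moves charge along a bond, adding to one atom exactly what it takes from the other, so the molecule's total charge is unchanged. As a worked check, this sketch shows that row 0 (methane) is consistent with that picture. It makes two assumptions. First, the charge arrays follow the atom-map order, so index 0 is the carbon `[C:1]`. Second, the three hydrogen charges hidden by the truncated output equal the first one, as methane's symmetry suggests. It derives a single C-H correction from the shown hydrogen values and uses it to predict the carbon charge. This is an arithmetic illustration, not how the toolkit computes BCCs.

```python
# Values copied from row 0 (methane) of the table above. Index 0 is assumed to be
# the carbon and index 1 the first hydrogen; the other three hydrogens are assumed
# equal to it because their values are truncated in the output.
am1_c, am1_h = -0.2658799886703491, 0.06646999716758728
bcc_c_reported, bcc_h_reported = -0.10868000239133835, 0.027170000597834587

n_bonds = 4  # four equivalent C-H bonds in methane

# Charge moved from each hydrogen to the carbon by one C-H correction.
delta = am1_h - bcc_h_reported

# Apply the same correction to every bond: each H loses delta, the C gains it.
bcc_h = am1_h - delta
bcc_c = am1_c + n_bonds * delta

print(f"predicted C: {bcc_c:.6f}  reported C: {bcc_c_reported:.6f}")
print(f"total AM1 charge:     {am1_c + n_bonds * am1_h:+.6f}")
print(f"total AM1-BCC charge: {bcc_c + n_bonds * bcc_h:+.6f}")
```

The predicted carbon charge matches the reported one, and both totals come out at zero, as expected for a neutral molecule.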


If you have your own charges to add, use the `LabelledDataset.append_columns` method. **Warning: this does not run any validity checks on the charges, such as their length or type!**
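Because `append_columns` does not validate its input, it can be worth checking your own charges first. The following stdlib sketch is a hypothetical helper, not a NAGL function. It checks that a charge list has one finite number per atom in the mapped SMILES and that the charges sum to the molecule's net charge. The default `net_charge=0.0` assumes neutral molecules such as these alkanes. The tolerance `tol` is an arbitrary illustrative choice. Note that the charges must also follow the atom-map order, which this check cannot detect.

```python
import math
import re

# bracket atoms carrying an atom-map index, e.g. "[H:2]"; assumes all atoms are bracketed
ATOM_MAP_PATTERN = re.compile(r"\[[^\]]*:(\d+)\]")


def validate_charges(mapped_smiles, charges, net_charge=0.0, tol=1e-4):
    """Raise ValueError if `charges` cannot be per-atom partial charges for the molecule."""
    n_atoms = len(ATOM_MAP_PATTERN.findall(mapped_smiles))
    if not isinstance(charges, (list, tuple)):
        raise ValueError(f"expected a list of charges, got {type(charges).__name__}")
    if len(charges) != n_atoms:
        raise ValueError(f"expected {n_atoms} charges, got {len(charges)}")
    if not all(isinstance(q, (int, float)) and math.isfinite(q) for q in charges):
        raise ValueError("charges must be finite numbers")
    total = sum(charges)
    if abs(total - net_charge) > tol:
        raise ValueError(f"charges sum to {total:.6f}, expected {net_charge}")
```

One way to use it is to call `validate_charges` on each mapped SMILES from `dataset.to_pandas()["mapped_smiles"]` together with the matching list of your own charges, before passing those lists to `append_columns`.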

In [7]:
dataset.append_columns(
    columns={
        # placeholder values: a one-element list per molecule, not real per-atom charges
        "custom_charges": [
            [i]
            for i in range(len(alkanes_smiles))
        ]
    }
)
dataset.to_pandas()

| | mapped_smiles | am1_charges | am1bcc_charges | custom_charges |
|---|---|---|---|---|
| 0 | [H:2][C:1]([H:3])([H:4])[H:5] | [-0.2658799886703491, 0.06646999716758728, 0.0... | [-0.10868000239133835, 0.027170000597834587, 0... | [0] |
| 1 | [H:3][C:1]([H:4])([H:5])[C:2]([H:6])([H:7])[H:8] | [-0.21174000017344952, -0.21174000017344952, 0... | [-0.09384000208228827, -0.09384000208228827, 0... | [1] |
| 2 | [H:4][C:1]([H:5])([H:6])[C:2]([H:7])([H:8])[C:... | [-0.21018000082536178, -0.15999999777837234, -... | [-0.09227999977090141, -0.08139999888160011, -... | [2] |
| 3 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([H:9])[C:... | [-0.21003000438213348, -0.15905000269412994, -... | [-0.09212999844125339, -0.08044999891093799, -... | [3] |
| 4 | [H:5][C:1]([H:6])([H:7])[C:2]([H:8])([C:3]([H:... | [-0.20747000138674462, -0.10981000374470438, -... | [-0.08957000076770782, -0.07050999999046326, -... | [4] |
| 5 | [H:6][C:1]([H:7])([H:8])[C:2]([H:9])([H:10])[C... | [-0.21004000306129456, -0.15812000632286072, -... | [-0.09213999658823013, -0.07952000200748444, -... | [5] |
| 6 | [H:15][C:5]([H:16])([H:17])[C:4]([H:13])([H:14... | [-0.20766000405830495, -0.10704000250381582, -... | [-0.0897599982426447, -0.06774000100353185, -0... | [6] |
| 7 | [H:7][C:1]([H:8])([H:9])[C:2]([H:10])([H:11])[... | [-0.21021999344229697, -0.15823000594973563, -... | [-0.0923200011253357, -0.0796300008893013, -0.... | [7] |
| 8 | [H:18][C:6]([H:19])([H:20])[C:5]([H:16])([H:17... | [-0.208649992197752, -0.1059999980032444, -0.2... | [-0.09075000137090683, -0.06669999659061432, -... | [8] |
| 9 | [H:13][C:4]([H:14])([H:15])[C:3]([H:11])([H:12... | [-0.2068299949169159, -0.10380999743938446, -0... | [-0.08893000297248363, -0.06451000235974788, -... | [9] |
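In the cell above, the custom values are passed as a list built in the same order as `alkanes_smiles`. The table rows also follow that input order. Together these suggest that `append_columns` matches values to rows by position. If your own charges are instead keyed by input SMILES (for example, in a dict), a hypothetical helper like this one can arrange them in row order first. It fails loudly if any molecule has no values. It assumes that rows follow the input order, as they do here. The charges within each list must still follow that row's atom-map order.

```python
def column_in_row_order(row_smiles, values_by_smiles):
    """Return values_by_smiles[s] for each s in row_smiles, raising KeyError on gaps."""
    missing = [s for s in row_smiles if s not in values_by_smiles]
    if missing:
        raise KeyError(f"no values for {len(missing)} molecule(s), e.g. {missing[0]!r}")
    return [values_by_smiles[s] for s in row_smiles]
```

For example, `column_in_row_order(alkanes_smiles, my_charges)` would produce a list suitable for the `columns` argument, where `my_charges` is your own (hypothetical) dict from input SMILES to charge lists.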
