In [1]:
import os
import schnetpack as spk
from schnetpack.datasets import QM9
import schnetpack.transform as trn
import numpy as np

import torch
import torchmetrics
import pytorch_lightning as pl

qm9tut = './qm9tut'
if not os.path.exists('qm9tut'):
    os.makedirs(qm9tut)


**NOTE:** Set batch size to 1!

In [2]:
# %rm split.npz

DB_PATH = "./qm9.db"
PROPERTY = QM9.homo
PROPERTIES = [PROPERTY]
BATCH_SIZE = 1
NUM_TRAIN = 110000
NUM_VALIDATION = 10000
CUTOFF = 5.
N_ATOM_BASIS = 32
T = 3
EPOCHS = 3
LR = 1e-4
NUM_WORKERS = 1
PIN_MEMORY = True

torch.manual_seed(0)

<torch._C.Generator at 0x7f2aff9230f0>

In [3]:
qm9data = QM9(
    DB_PATH,
    batch_size=BATCH_SIZE,
    num_train=NUM_TRAIN,
    num_val=NUM_VALIDATION,
    transforms=[
        trn.ASENeighborList(cutoff=float(CUTOFF)),
        trn.CastTo32()
    ],
    num_workers=NUM_WORKERS,
    split_file=os.path.join(qm9tut, "split.npz"),
    pin_memory=PIN_MEMORY, # set to false, when not using a GPU
    load_properties=PROPERTIES, #only load U0 property
)
qm9data.prepare_data()
qm9data.setup()

In [4]:
from data_handler import QM9DataHandler

dh = QM9DataHandler(qm9data)

dh.fetch_data(PROPERTIES)

100%|██████████| 10000/10000 [01:44<00:00, 95.77it/s]


To create a filter to more easily extract individual atoms from the embeddings, use the `set_atom_isolation`-function. `QM9DataHandler` will work like an iterable and return all relevant data in a `dict`-format. The `dict` will contain the following keys:
- `positions`: The positions of the atoms in the molecule
- `atom_numbers`: The atomic numbers in sequence in the given molecule.
- `atom_mask`: A mask to indicate the positions of the chosen atom in the molecule.
- `properties`: The propery values of the molecule.

In [5]:
dh.set_atom_isolation(QM9DataHandler.C)

for data in dh:
    print(data)
    break

{'positions': array([[ 0.0252458 ,  1.4970046 ,  0.08615518],
       [ 1.3247613 ,  0.7228304 ,  0.00887963],
       [ 2.0554478 ,  0.34365776,  1.2268586 ],
       [ 1.124637  , -0.4352904 ,  1.9408693 ],
       [-0.02777087, -0.74836653,  1.2716168 ],
       [-0.7649887 , -1.6382285 ,  2.10268   ],
       [-0.02362859, -1.8164945 ,  3.2514277 ],
       [ 1.145672  , -1.0650512 ,  3.1468434 ],
       [-0.01373413, -0.01781819, -0.02538622],
       [-0.2730558 ,  1.9381217 ,  1.0315485 ],
       [-0.231681  ,  2.078949  , -0.7932714 ],
       [ 1.9452933 ,  0.8327048 , -0.8731382 ],
       [ 2.4170675 ,  1.1480627 ,  1.7297515 ],
       [-1.7246846 , -2.091034  ,  1.9050224 ],
       [-0.22048482, -2.3963943 ,  4.1378865 ],
       [ 1.8915768 , -1.0547559 ,  3.8193238 ],
       [-0.37758958, -0.47303197, -0.93970454]], dtype=float32), 'atom_numbers': array([6, 6, 7, 6, 6, 6, 6, 7, 6, 1, 1, 1, 1, 1, 1, 1, 1]), 'properties': array([-0.1629]), 'atom_mask': array([ True,  True, False,  Tru

When data is fetch, you can parse a trained model to the `fetch_model_output`-function. This will return the output of the model for the given data. The output will update the previous `dict`, such that it will contain the following keys:
- `positions`: The positions of the atoms in the molecule
- `atom_numbers`: The atomic numbers in sequence in the given molecule.
- `atom_mask`: A mask to indicate the positions of the chosen atom in the molecule.
- `properties`: The propery values of the molecule.
- `embeddings`: The embeddings of the atoms in the molecule.
- `predictions`: The output of the model for the given molecule.

In [6]:
model = torch.load("./best_homo_e50.pt", map_location=torch.device('cpu'))
dh.fetch_model_outputs(model)

100%|██████████| 10000/10000 [05:01<00:00, 33.18it/s]


In [7]:
for data in dh:
    print(data)
    break

{'positions': array([[ 0.0252458 ,  1.4970046 ,  0.08615518],
       [ 1.3247613 ,  0.7228304 ,  0.00887963],
       [ 2.0554478 ,  0.34365776,  1.2268586 ],
       [ 1.124637  , -0.4352904 ,  1.9408693 ],
       [-0.02777087, -0.74836653,  1.2716168 ],
       [-0.7649887 , -1.6382285 ,  2.10268   ],
       [-0.02362859, -1.8164945 ,  3.2514277 ],
       [ 1.145672  , -1.0650512 ,  3.1468434 ],
       [-0.01373413, -0.01781819, -0.02538622],
       [-0.2730558 ,  1.9381217 ,  1.0315485 ],
       [-0.231681  ,  2.078949  , -0.7932714 ],
       [ 1.9452933 ,  0.8327048 , -0.8731382 ],
       [ 2.4170675 ,  1.1480627 ,  1.7297515 ],
       [-1.7246846 , -2.091034  ,  1.9050224 ],
       [-0.22048482, -2.3963943 ,  4.1378865 ],
       [ 1.8915768 , -1.0547559 ,  3.8193238 ],
       [-0.37758958, -0.47303197, -0.93970454]], dtype=float32), 'atom_numbers': array([6, 6, 7, 6, 6, 6, 6, 7, 6, 1, 1, 1, 1, 1, 1, 1, 1]), 'properties': array([-0.1629]), 'atom_mask': array([ True,  True, False,  Tru