In [1]:
import os
import schnetpack as spk
from schnetpack.datasets import QM9
import schnetpack.transform as trn

import torch
import torchmetrics
import pytorch_lightning as pl

# Working directory for the dataset split file, TensorBoard logs and checkpoints.
qm9tut = './qm9tut'
# exist_ok=True checks the same path that is created and makes re-running this
# cell safe (no FileExistsError, no divergence if the path above is edited).
os.makedirs(qm9tut, exist_ok=True)

In [2]:
DB_PATH = "./qm9.db"
# Fine-tuning target. The checkpoint loaded later is named 'best_homo_e50.pt',
# which suggests it was trained on HOMO; here it is adapted to LUMO.
PROPERTY = QM9.lumo
BATCH_SIZE = 16
NUM_TRAIN = 110000
NUM_VALIDATION = 10000
CUTOFF = 5.                  # neighbor-list / SchNet cutoff radius (presumably Angstrom — confirm)
N_ATOM_BASIS = 32            # width of the SchNet atom-wise feature vectors
T = 3                        # number of SchNet interaction blocks
EPOCHS = 1
LR = 1e-4
NUM_WORKERS = 1
# Pinned host memory only speeds up host->GPU copies; disable it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()
FREEZE = True                # if True, only the output head is fine-tuned
SEED = 42

# Seed Python, NumPy and torch RNGs so shuffling/initialisation are reproducible.
pl.seed_everything(SEED)

In [3]:
# Per-sample preprocessing: build neighbor lists within CUTOFF, then cast to float32.
# No offset/mean removal is applied (trn.RemoveOffsets is intentionally not used),
# so the target is used as stored in the database.
qm9_transforms = [
    trn.ASENeighborList(cutoff=float(CUTOFF)),
    trn.CastTo32(),
]

qm9data = QM9(
    DB_PATH,
    batch_size=BATCH_SIZE,
    num_train=NUM_TRAIN,
    num_val=NUM_VALIDATION,
    transforms=qm9_transforms,
    num_workers=NUM_WORKERS,
    split_file=os.path.join(qm9tut, "split.npz"),  # split indices file under qm9tut
    pin_memory=PIN_MEMORY,  # set to False when not using a GPU
    load_properties=[PROPERTY],  # only load the target property
)
qm9data.prepare_data()
qm9data.setup()

In [4]:
# Input module: computes pairwise interatomic distances from the neighbor list.
pairwise_distance = spk.atomistic.PairwiseDistances()

# Distance featurisation for SchNet: 20 Gaussian RBFs, smoothly cut off at CUTOFF.
radial_basis = spk.nn.GaussianRBF(n_rbf=20, cutoff=CUTOFF)
cosine_cutoff = spk.nn.CosineCutoff(CUTOFF)
schnet = spk.representation.SchNet(
    n_atom_basis=N_ATOM_BASIS,
    n_interactions=T,
    radial_basis=radial_basis,
    cutoff_fn=cosine_cutoff,
)

# Output head: Atomwise module whose prediction is written under the PROPERTY key.
pred_property = spk.atomistic.Atomwise(n_in=N_ATOM_BASIS, output_key=PROPERTY)

nnpot = spk.model.NeuralNetworkPotential(
    representation=schnet,
    input_modules=[pairwise_distance],
    output_modules=[pred_property],
    postprocessors=[trn.CastTo64()],  # cast outputs back to float64
)

In [5]:
# Transfer learning: initialise nnpot from a previously trained checkpoint.
PRETRAINED_PATH = os.path.join('weights', 'best_homo_e50.pt')

# The checkpoint is a fully pickled nn.Module (we call .state_dict() on it), not a
# bare state_dict, so weights_only=False is required on recent PyTorch versions
# whose default is True. Unpickling can execute arbitrary code: only load
# checkpoints you trust.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
# Lightning moves the model to the selected device when training starts.
pretrained_model = torch.load(PRETRAINED_PATH, map_location="cpu", weights_only=False)
nnpot.load_state_dict(pretrained_model.state_dict())

if FREEZE:
    # Freeze the whole SchNet representation so only the output head is trained.
    # Selecting by module (rather than matching 'output' in parameter names)
    # cannot accidentally leave a representation parameter trainable.
    for param in nnpot.representation.parameters():
        param.requires_grad = False

# Sanity check: number of parameters the optimizer will actually update.
n_trainable = sum(p.numel() for p in nnpot.parameters() if p.requires_grad)
n_trainable


In [6]:
# Training objective for PROPERTY: MSE loss (weight 1), tracked with an MAE metric.
property_metrics = {"MAE": torchmetrics.MeanAbsoluteError()}
output_property = spk.task.ModelOutput(
    name=PROPERTY,
    loss_fn=torch.nn.MSELoss(),
    loss_weight=1.0,
    metrics=property_metrics,
)

In [7]:
# Lightning task bundling the model, its output/loss definitions and the optimizer.
optimizer_kwargs = dict(lr=LR)
task = spk.task.AtomisticTask(
    model=nnpot,
    outputs=[output_property],
    optimizer_cls=torch.optim.AdamW,
    optimizer_args=optimizer_kwargs,
)

/home/aimas/dtu/dl/DeepLearningProject/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.


In [8]:
# TensorBoard logs go under qm9tut.
logger = pl.loggers.TensorBoardLogger(save_dir=qm9tut)

# Keep only the single best model according to the validation loss.
best_model_checkpoint = spk.train.ModelCheckpoint(
    model_path=os.path.join(qm9tut, "best_inference_model"),
    save_top_k=1,
    monitor="val_loss",
)
callbacks = [best_model_checkpoint]

# NOTE: on Tensor-Core GPUs Lightning suggests torch.set_float32_matmul_precision
# ('medium' | 'high') to trade float32 matmul precision for speed.
trainer = pl.Trainer(
    callbacks=callbacks,
    logger=logger,
    default_root_dir=qm9tut,
    max_epochs=EPOCHS,  # kept small for testing
)
trainer.fit(task, datamodule=qm9data)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 18.3 K
1 | outputs | ModuleList             | 0     
---------------------------------------------------
545       Trainable params
17.8 K    Non-trainable params
18.3 K    Total params
0.073     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/aimas/dtu/dl/DeepLearningProject/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

/home/aimas/dtu/dl/DeepLearningProject/.venv/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 16. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/aimas/dtu/dl/DeepLearningProject/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 6875/6875 [09:28<00:00, 12.08it/s, v_num=13, val_loss=0.000545]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 6875/6875 [09:28<00:00, 12.08it/s, v_num=13, val_loss=0.000545]
