In [1]:
"""
Author: TMJ
Date: 2024-05-13 18:12:11
LastEditors: TMJ
LastEditTime: 2024-05-19 15:05:50
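Description: Train DimeNetPPCM on the QM9star dataset (energy-only), using energies referenced to isolated-atom single-point energies.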
"""

import datetime
import glob
import os
import shutil

import numpy as np
import torch
from dig.threedgraph.evaluation.eval import ThreeDEvaluator
from dig.threedgraph.method.run import run
from nff.io import AtomsBatch
from nff.train import Adam, Trainer, hooks, loss, metrics
from pint import UnitRegistry
from rdkit import Chem
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from qm9star_query.dataset.base_dataset import BaseQM9starDataset
from qm9star_query.dataset.sub_datasets import (
    AnionQM9starDataset,
    CationQM9starDataset,
    NeutralQM9starDataset,
    RadicalQM9starDataset,
)
from qm9star_query.nn.dimenetpp import DimeNetPPCM

# pint supplies the unit-conversion factors used throughout this notebook.
ureg = UnitRegistry()

# Multiplicative factors: (value in source unit) * factor = (value in target unit).
hartree2kcal = ureg.convert(1, "hartree/particle", "kcal/mol")
hartree_bohr2kcal_mol_angstrom = ureg.convert(
    1, "hartree/bohr/particle", "kcal/mol/angstrom"
)

# Per-element reference energies keyed by element symbol
# (energy from /tutorial/atom_ref). The literals are treated as Hartree and
# converted to kcal/mol once, here, so lookups need no further conversion.
atom_single_point_energy = {
    symbol: energy_hartree * hartree2kcal
    for symbol, energy_hartree in (
        ("H", -0.5021559),
        ("C", -37.7375894),
        ("N", -54.4992609),
        ("O", -74.9889063),
        ("F", -99.7605802),
    )
}


def get_total_atom_energy(atom_list: list[int]):
    """Sum the isolated-atom reference energies for a list of atoms.

    Args:
        atom_list: Atomic numbers, one entry per atom. Each number is turned
            into an element symbol via RDKit and looked up in
            ``atom_single_point_energy``.

    Returns:
        The accumulated reference energy, in the units stored in
        ``atom_single_point_energy``. An empty list yields ``0``.

    Raises:
        KeyError: If an element symbol has no entry in
            ``atom_single_point_energy``.
    """
    total = 0
    # Accumulate strictly left to right so the floating-point result does not
    # depend on the summation strategy of a built-in.
    for atomic_number in atom_list:
        symbol = Chem.Atom(atomic_number).GetSymbol()
        total += atom_single_point_energy[symbol]
    return total


# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


def transform_data(raw_data):
    """Convert one raw dataset record into a ``torch_geometric`` ``Data`` object.

    Args:
        raw_data: Mapping that provides the keys ``"coords"``, ``"atoms"``,
            ``"single_point_energy"``, ``"formal_charges"``,
            ``"formal_num_radicals"`` and ``"forces"``.
            NOTE(review): the energy and forces are multiplied by the
            Hartree-based conversion factors, so they are presumably stored in
            Hartree and Hartree/Bohr - confirm against the dataset schema.

    Returns:
        A ``Data`` object with these fields:
            pos: coordinates as float32.
            nxyz: atomic numbers (as a column) concatenated with ``pos`` along
                the last dimension.
            formal_charges / formal_num_radicals: int64 tensors.
            energy: single-point energy converted with ``hartree2kcal`` minus
                the summed per-atom reference energies (float32).
            energy_grad: negated forces scaled by
                ``hartree_bohr2kcal_mol_angstrom`` (float32).
    """
    coordinates = torch.tensor(raw_data["coords"], dtype=torch.float32)
    atomic_numbers = torch.tensor(raw_data["atoms"], dtype=torch.int64)

    # Reference the total energy to the isolated atoms before casting to float32.
    referenced_energy = (
        raw_data["single_point_energy"] * hartree2kcal
        - get_total_atom_energy(raw_data["atoms"])
    )
    energy_tensor = torch.tensor(referenced_energy, dtype=torch.float32)

    # One row per atom: atomic number followed by its coordinates.
    nxyz = torch.cat([atomic_numbers.view(-1, 1), coordinates], dim=-1)

    charges = torch.tensor(raw_data["formal_charges"], dtype=torch.int64)
    radicals = torch.tensor(raw_data["formal_num_radicals"], dtype=torch.int64)

    # The energy gradient is the negative of the force.
    forces = torch.tensor(raw_data["forces"], dtype=torch.float32)
    gradient = -forces * hartree_bohr2kcal_mol_angstrom

    return Data(
        pos=coordinates,
        nxyz=nxyz,
        formal_charges=charges,
        formal_num_radicals=radicals,
        energy=energy_tensor,
        energy_grad=gradient,
    )


# Database connection settings are read from the environment so credentials do
# not have to be edited into (or committed with) the notebook. The fallback
# values reproduce the previous hard-coded local setup, so an unconfigured
# environment behaves exactly as before.
#   QM9STAR_DB_USER / QM9STAR_DB_PASSWORD - login credentials
#   QM9STAR_DB_SERVER / QM9STAR_DB_PORT   - host and port of the database server
#   QM9STAR_DB_NAME                       - database name
total_dataset = BaseQM9starDataset(
    user=os.environ.get("QM9STAR_DB_USER", "hxchem"),
    password=os.environ.get("QM9STAR_DB_PASSWORD", "hxchem"),
    server=os.environ.get("QM9STAR_DB_SERVER", "127.0.0.1"),
    # Environment values are strings; the constructor was given an int before.
    port=int(os.environ.get("QM9STAR_DB_PORT", "35432")),
    db=os.environ.get("QM9STAR_DB_NAME", "qm9star"),
    dataset_name="qm9star_all",
    block_num=5,
    log=True,
    # Every record is converted to a torch_geometric Data object on access.
    transform=transform_data,
)

In [2]:
# Fractions of the full dataset assigned to each subset.
train_split = 0.8
val_split = 0.1
# Not used in the slicing below: the test set is whatever remains after the
# training and validation slices have been taken.
test_split = 0.1

# Shuffle the dataset indices in place with a fixed seed so the split is
# reproducible across runs.
rng = np.random.default_rng(seed=3407)
ids = list(total_dataset.indices())
rng.shuffle(ids)

# Cut points into the shuffled index list.
n_samples = len(ids)
train_end = int(n_samples * train_split)
val_end = int(n_samples * (train_split + val_split))

train_ids, val_ids, test_ids = (
    ids[:train_end],
    ids[train_end:val_end],
    ids[val_end:],
)

print(f"Train: {len(train_ids)}, Valid: {len(val_ids)}, test: {len(test_ids)}")

# Index the dataset with each id list to obtain the three subsets.
train_dataset, valid_dataset, test_dataset = (
    total_dataset[train_ids],
    total_dataset[val_ids],
    total_dataset[test_ids],
)

Train: 1609571, Valid: 201196, test: 201197


In [3]:
# Energy-only model: force prediction is switched off and results are
# requested via ret_res_dict=True.
model = DimeNetPPCM(
    energy_and_force=False,
    ret_res_dict=True,
)

# Directory that receives the training artefacts of this run.
OUTDIR = "./sandbox"

# Keep exactly one previous run: an existing OUTDIR is moved to a sibling
# "backup" directory, and any older backup is deleted first.
if os.path.exists(OUTDIR):
    newpath = os.path.join(os.path.dirname(OUTDIR), "backup")
    if os.path.exists(newpath):
        shutil.rmtree(newpath)
    shutil.move(OUTDIR, newpath)

In [4]:
# Mini-batch loaders (128 samples per batch). Only the training loader
# reshuffles; validation and test keep a fixed order. follow_batch=["nxyz"]
# asks torch_geometric to also emit a batch-assignment vector for "nxyz".
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    follow_batch=["nxyz"],
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=128,
    shuffle=False,
    follow_batch=["nxyz"],
)
test_loader = DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    follow_batch=["nxyz"],
)

In [5]:
# Sanity check: fetch a single batch from the training loader, print its
# "batch" entry, and stop immediately so the rest of the loader is not
# iterated.
for batch in train_loader:
    print(batch["batch"])
    break

tensor([  0,   0,   0,  ..., 127, 127, 127])


In [6]:
# Mean-absolute-error loss on the "energy" target only (coefficient 1).
loss_fn = loss.build_mae_loss(loss_coef={"energy": 1})

# Materialise the trainable parameters as a list: a bare filter() object is a
# single-use iterator that would be empty after the optimizer consumes it.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=1e-4)

# Metrics written by the logging hooks below. The energy_grad metric stays
# disabled because the model is built with energy_and_force=False.
train_metrics = [
    metrics.MeanAbsoluteError("energy"),
    # metrics.MeanAbsoluteError("energy_grad"),
]

train_hooks = [
    # NOTE(review): T.train() is called with n_epochs=400 in a later cell;
    # presumably the lower of the two limits ends training - confirm.
    hooks.MaxEpochHook(300),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator=" | ",
        # Include the date and hour. The previous "%M:%S" format dropped the
        # hour, so the Time column wrapped every hour and log lines from a
        # multi-hour run could not be ordered or timed unambiguously.
        time_strf="%Y-%m-%d %H:%M:%S",
    ),
    # Learning-rate schedule driven by the validation loss.
    # NOTE(review): presumably multiplies the LR by `factor` after `patience`
    # epochs without improvement and stops once `min_lr` is reached - confirm
    # against the nff hook implementation.
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=30,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True,
    ),
]

# Trainer writes to OUTDIR and checkpoints with checkpoint_interval=1.
T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=valid_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
)

In [7]:
# Start training on the selected device.
# NOTE(review): train_hooks contains MaxEpochHook(300) while n_epochs is 400
# here; confirm which limit is intended.
T.train(
    n_epochs=400,
    device=device,
)

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy | GPU Memory (MB)


Please either pass the dim explicitly or simply use torch.linalg.cross.
The default value of dim will change to agree with that of linalg.cross in a future release. (Triggered internally at ../aten/src/ATen/native/Cross.cpp:62.)
  b = torch.cross(pos_ji, pos_jk).norm(dim=-1) # sin_angle * |pos_ji| * |pos_jk|
100%|█████████▉| 12574/12575 [22:30<00:00,  9.31it/s]


31:31 |     1 |     1.000e-04 |    18.3860 |          5.3132 |     5.3146 |            4792


100%|█████████▉| 12574/12575 [20:32<00:00, 10.20it/s]


53:11 |     2 |     1.000e-04 |     7.2004 |          9.5559 |     9.5563 |            4792


100%|█████████▉| 12574/12575 [20:21<00:00, 10.29it/s]


14:41 |     3 |     1.000e-04 |     5.2391 |          5.5432 |     5.5447 |            4861


100%|█████████▉| 12574/12575 [20:22<00:00, 10.29it/s]


36:10 |     4 |     1.000e-04 |     4.3394 |          5.1335 |     5.1326 |            4861


100%|█████████▉| 12574/12575 [20:21<00:00, 10.30it/s]


57:40 |     5 |     1.000e-04 |     3.7286 |          2.4910 |     2.4913 |            4861


100%|█████████▉| 12574/12575 [20:21<00:00, 10.30it/s]


19:07 |     6 |     1.000e-04 |     3.2503 |          3.5656 |     3.5662 |            4861


100%|█████████▉| 12574/12575 [20:23<00:00, 10.28it/s] 


40:38 |     7 |     1.000e-04 |     3.0739 |          3.6992 |     3.6999 |            4861


100%|█████████▉| 12574/12575 [20:19<00:00, 10.31it/s]


02:04 |     8 |     1.000e-04 |     2.8494 |          3.4633 |     3.4625 |            4861


100%|█████████▉| 12574/12575 [25:30<00:00,  8.22it/s]


29:01 |     9 |     1.000e-04 |     2.6640 |          2.0840 |     2.0841 |            4861


100%|█████████▉| 12574/12575 [25:48<00:00,  8.12it/s]


56:19 |    10 |     1.000e-04 |     2.3781 |          1.4193 |     1.4196 |            4861


100%|█████████▉| 12574/12575 [26:01<00:00,  8.05it/s]


23:28 |    11 |     1.000e-04 |     2.3373 |          1.4043 |     1.4045 |            4861


100%|█████████▉| 12574/12575 [22:13<00:00,  9.43it/s]


46:49 |    12 |     1.000e-04 |     2.2173 |          1.9574 |     1.9571 |            4861


100%|█████████▉| 12574/12575 [20:13<00:00, 10.37it/s]


08:07 |    13 |     1.000e-04 |     2.0534 |          2.6752 |     2.6752 |            4861


100%|█████████▉| 12574/12575 [19:54<00:00, 10.53it/s]


29:07 |    14 |     1.000e-04 |     1.9961 |          2.0678 |     2.0679 |            4861


100%|█████████▉| 12574/12575 [20:07<00:00, 10.41it/s]


50:22 |    15 |     1.000e-04 |     1.8757 |          1.0393 |     1.0396 |            4894


100%|█████████▉| 12574/12575 [20:03<00:00, 10.45it/s]


11:32 |    16 |     1.000e-04 |     1.8548 |          1.7518 |     1.7519 |            4894


100%|█████████▉| 12574/12575 [20:09<00:00, 10.39it/s]


32:49 |    17 |     1.000e-04 |     1.7844 |          1.6599 |     1.6602 |            4894


100%|█████████▉| 12574/12575 [20:25<00:00, 10.26it/s]


54:20 |    18 |     1.000e-04 |     1.6844 |          1.3121 |     1.3124 |            4894


100%|█████████▉| 12574/12575 [20:25<00:00, 10.26it/s]


15:56 |    19 |     1.000e-04 |     1.6317 |          1.6805 |     1.6806 |            4894


100%|█████████▉| 12574/12575 [20:28<00:00, 10.23it/s]


37:32 |    20 |     1.000e-04 |     1.5832 |          1.1620 |     1.1622 |            4894


100%|█████████▉| 12574/12575 [20:25<00:00, 10.26it/s]


59:05 |    21 |     1.000e-04 |     1.5142 |          1.8573 |     1.8569 |            4894


100%|█████████▉| 12574/12575 [20:24<00:00, 10.27it/s]


20:36 |    22 |     1.000e-04 |     1.4807 |          1.5073 |     1.5075 |            4894


100%|█████████▉| 12574/12575 [20:27<00:00, 10.25it/s]


42:10 |    23 |     1.000e-04 |     1.4434 |          0.8954 |     0.8957 |            4894


100%|█████████▉| 12574/12575 [20:31<00:00, 10.21it/s] 


03:50 |    24 |     1.000e-04 |     1.4195 |          0.8045 |     0.8048 |            4894


100%|█████████▉| 12574/12575 [20:38<00:00, 10.16it/s]


25:35 |    25 |     1.000e-04 |     1.3785 |          2.0060 |     2.0059 |            4894


100%|█████████▉| 12574/12575 [20:40<00:00, 10.14it/s]


47:23 |    26 |     1.000e-04 |     1.3361 |          2.6665 |     2.6669 |            4894


100%|█████████▉| 12574/12575 [20:37<00:00, 10.16it/s]


09:08 |    27 |     1.000e-04 |     1.3444 |          1.0219 |     1.0224 |            4895


100%|█████████▉| 12574/12575 [20:38<00:00, 10.15it/s]


30:55 |    28 |     1.000e-04 |     1.3014 |          0.9993 |     0.9995 |            4895


100%|█████████▉| 12574/12575 [20:28<00:00, 10.24it/s]


52:33 |    29 |     1.000e-04 |     1.2381 |          0.8642 |     0.8645 |            4895


100%|█████████▉| 12574/12575 [20:22<00:00, 10.29it/s]


14:02 |    30 |     1.000e-04 |     1.2377 |          0.7929 |     0.7931 |            4895


100%|█████████▉| 12574/12575 [20:26<00:00, 10.25it/s]


35:35 |    31 |     1.000e-04 |     1.2212 |          0.7594 |     0.7597 |            4895


100%|█████████▉| 12574/12575 [20:24<00:00, 10.26it/s]


57:08 |    32 |     1.000e-04 |     1.2024 |          0.7771 |     0.7774 |            4895


100%|█████████▉| 12574/12575 [20:23<00:00, 10.27it/s]


18:38 |    33 |     1.000e-04 |     1.1639 |          0.7579 |     0.7581 |            4895


100%|█████████▉| 12574/12575 [20:24<00:00, 10.27it/s]


40:10 |    34 |     1.000e-04 |     1.1392 |          0.7265 |     0.7267 |            4895


100%|█████████▉| 12574/12575 [20:28<00:00, 10.24it/s]


01:44 |    35 |     1.000e-04 |     1.1312 |          0.9224 |     0.9224 |            4895


100%|█████████▉| 12574/12575 [20:25<00:00, 10.26it/s]


23:16 |    36 |     1.000e-04 |     1.1108 |          0.8773 |     0.8774 |            4895


100%|█████████▉| 12574/12575 [20:23<00:00, 10.27it/s]


44:48 |    37 |     1.000e-04 |     1.0901 |          1.5021 |     1.5018 |            4895


100%|█████████▉| 12574/12575 [20:24<00:00, 10.27it/s]


06:18 |    38 |     1.000e-04 |     1.0722 |          1.3079 |     1.3082 |            4895


100%|█████████▉| 12574/12575 [20:26<00:00, 10.25it/s]


27:53 |    39 |     1.000e-04 |     1.0534 |          1.7409 |     1.7410 |            4895


100%|█████████▉| 12574/12575 [20:25<00:00, 10.26it/s]


49:25 |    40 |     1.000e-04 |     1.0465 |          0.6650 |     0.6652 |            4895


100%|█████████▉| 12574/12575 [20:28<00:00, 10.23it/s] 


11:01 |    41 |     1.000e-04 |     1.0432 |          0.9488 |     0.9489 |            4895


100%|█████████▉| 12574/12575 [20:22<00:00, 10.28it/s]


32:30 |    42 |     1.000e-04 |     0.9976 |          1.4518 |     1.4516 |            4895


100%|█████████▉| 12574/12575 [20:22<00:00, 10.28it/s]


54:00 |    43 |     1.000e-04 |     1.0127 |          0.6030 |     0.6032 |            4895


100%|█████████▉| 12574/12575 [20:23<00:00, 10.28it/s]


15:31 |    44 |     1.000e-04 |     0.9863 |          2.0577 |     2.0575 |            4895


100%|█████████▉| 12574/12575 [20:24<00:00, 10.27it/s]


37:02 |    45 |     1.000e-04 |     0.9802 |          0.6487 |     0.6489 |            4895


100%|█████████▉| 12574/12575 [20:23<00:00, 10.28it/s]


58:35 |    46 |     1.000e-04 |     0.9777 |          0.8079 |     0.8082 |            4895


100%|█████████▉| 12574/12575 [20:33<00:00, 10.19it/s]


20:17 |    47 |     1.000e-04 |     0.9269 |          1.0591 |     1.0591 |            4895


100%|█████████▉| 12574/12575 [20:32<00:00, 10.20it/s]


41:58 |    48 |     1.000e-04 |     0.9410 |          0.6849 |     0.6851 |            4895


  1%|          | 93/12575 [00:09<21:35,  9.63it/s]


KeyboardInterrupt: 