In [None]:
"""
Author: TMJ
Date: 2024-05-13 18:12:11
LastEditors: TMJ
LastEditTime: 2024-05-19 15:05:50
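Description: Train a DimeNet++ model (DimeNetPPCM) on the QM9star dataset to predict molecular energies relative to isolated-atom reference energies.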
"""

import datetime
import glob
import os
import shutil

import numpy as np
import torch
from dig.threedgraph.evaluation.eval import ThreeDEvaluator
from dig.threedgraph.method.run import run
from nff.io import AtomsBatch
from nff.train import Adam, Trainer, hooks, loss, metrics
from pint import UnitRegistry
from rdkit import Chem
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from qm9star_query.dataset.base_dataset import BaseQM9starDataset
from qm9star_query.dataset.sub_datasets import (
    AnionQM9starDataset,
    CationQM9starDataset,
    NeutralQM9starDataset,
    RadicalQM9starDataset,
)
from qm9star_query.nn.dimenetpp import DimeNetPPCM

# Unit conversion factors computed with pint: per-particle hartree quantities
# are converted to molar kcal-based units.
ureg = UnitRegistry()
hartree2kcal = ureg.convert(1, "hartree/particle", "kcal/mol")
hartree_bohr2kcal_mol_angstrom = ureg.convert(
    1, "hartree/bohr/particle", "kcal/mol/angstrom"
)

# Isolated-atom single-point reference energies (hartree, from
# /tutorial/atom_ref), keyed by element symbol and converted to kcal/mol.
atom_single_point_energy = {
    symbol: energy_hartree * hartree2kcal
    for symbol, energy_hartree in (
        ("H", -0.5021559),
        ("C", -37.7375894),
        ("N", -54.4992609),
        ("O", -74.9889063),
        ("F", -99.7605802),
    )
}


def get_total_atom_energy(atom_list: list[int]):
    """Return the summed isolated-atom reference energy of a molecule.

    Args:
        atom_list: atomic numbers of the molecule's atoms; each is mapped to
            its element symbol with RDKit.

    Returns:
        Sum of ``atom_single_point_energy`` entries (kcal/mol); 0 for an
        empty list.

    Raises:
        KeyError: if an element has no entry in ``atom_single_point_energy``.
    """
    return sum(
        atom_single_point_energy[Chem.Atom(atomic_number).GetSymbol()]
        for atomic_number in atom_list
    )


# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


def transform_data(raw_data):
    """Convert one raw QM9star record into a PyG ``Data`` graph.

    Args:
        raw_data: mapping with keys "coords", "atoms" (atomic numbers),
            "single_point_energy" (hartree), "bonds", "formal_charges",
            "formal_num_radicals" and "forces" (hartree/bohr).

    Returns:
        ``Data`` with ``pos`` (float32), ``z`` (int64), ``bonds``,
        ``nxyz`` (atomic number column concatenated with positions),
        ``formal_charges``, ``formal_num_radicals``, ``energy`` (kcal/mol,
        relative to the summed isolated-atom references) and ``energy_grad``
        (negated forces in kcal/mol/angstrom).
    """
    atomic_numbers = raw_data["atoms"]
    pos = torch.tensor(raw_data["coords"], dtype=torch.float32)
    z = torch.tensor(atomic_numbers, dtype=torch.int64)

    # Subtract the atomic reference energies in Python floats before casting,
    # so the large absolute energy never has to be represented in float32.
    relative_energy = (
        raw_data["single_point_energy"] * hartree2kcal
        - get_total_atom_energy(atomic_numbers)
    )
    energy = torch.tensor(relative_energy, dtype=torch.float32)

    # NOTE(review): if "bonds" holds atom indices, PyG will not offset it when
    # batching (only keys containing "index" are incremented) — confirm that no
    # consumer relies on batched bond indices.
    bonds = torch.tensor(raw_data["bonds"], dtype=torch.int64)
    nxyz = torch.concat([z.view(-1, 1), pos], dim=-1)
    formal_charges = torch.tensor(raw_data["formal_charges"], dtype=torch.int64)
    formal_num_radicals = torch.tensor(
        raw_data["formal_num_radicals"], dtype=torch.int64
    )
    # The energy gradient is minus the force, rescaled to kcal/mol/angstrom.
    forces = torch.tensor(raw_data["forces"], dtype=torch.float32)
    energy_grad = -forces * hartree_bohr2kcal_mol_angstrom

    return Data(
        pos=pos,
        z=z,
        bonds=bonds,
        nxyz=nxyz,
        formal_charges=formal_charges,
        formal_num_radicals=formal_num_radicals,
        energy=energy,
        energy_grad=energy_grad,
    )


# Connection settings come from the environment so credentials need not be
# committed with the notebook. The fallbacks reproduce the previous hard-coded
# values so existing setups keep working.
# TODO: remove the credential fallbacks once QM9STAR_DB_* variables are set.
total_dataset = BaseQM9starDataset(
    user=os.environ.get("QM9STAR_DB_USER", "hxchem"),
    password=os.environ.get("QM9STAR_DB_PASSWORD", "hxchem"),
    server=os.environ.get("QM9STAR_DB_HOST", "127.0.0.1"),
    port=int(os.environ.get("QM9STAR_DB_PORT", "35432")),
    db=os.environ.get("QM9STAR_DB_NAME", "qm9star"),
    dataset_name="qm9star_all",
    block_num=5,
    log=True,
    # Each raw record is converted to a PyG graph on access.
    transform=transform_data,
)

In [None]:
# Random train/validation/test split of the dataset indices. The generator is
# seeded so the split is reproducible across runs.
train_split = 0.8
val_split = 0.1
test_split = 0.1
# The test set receives whatever remains after train and validation, so
# test_split is only honored if the three fractions cover the whole dataset.
assert np.isclose(
    train_split + val_split + test_split, 1.0
), "train/val/test split fractions must sum to 1"

rng = np.random.default_rng(seed=3407)
ids = list(total_dataset.indices())
rng.shuffle(ids)

# Compute each boundary once so the slices below cannot drift apart.
n_total = len(ids)
train_end = int(n_total * train_split)
val_end = int(n_total * (train_split + val_split))

train_ids = ids[:train_end]
val_ids = ids[train_end:val_end]
test_ids = ids[val_end:]

print(f"Train: {len(train_ids)}, Valid: {len(val_ids)}, test: {len(test_ids)}")

train_dataset = total_dataset[train_ids]
valid_dataset = total_dataset[val_ids]
test_dataset = total_dataset[test_ids]

In [None]:
# Seed torch's global RNG before building the model so weight initialization
# is reproducible. It presumably also fixes train_loader's shuffling, which
# falls back to the global generator when none is passed. 3407 matches the
# seed of the numpy index split.
torch.manual_seed(3407)
model = DimeNetPPCM(energy_and_force=False, ret_res_dict=True)

# Checkpoints and logs are written to OUTDIR. A directory left by a previous
# run is moved to a sibling "backup" directory. Any older backup is deleted
# first, so only the most recent previous run is kept.
OUTDIR = "./sandbox"
if os.path.exists(OUTDIR):
    # normpath strips a trailing separator. Without it, dirname("./sandbox/")
    # is "./sandbox" and the backup would be moved into OUTDIR itself.
    newpath = os.path.join(os.path.dirname(os.path.normpath(OUTDIR)), "backup")
    if os.path.exists(newpath):
        shutil.rmtree(newpath)

    shutil.move(OUTDIR, newpath)

In [None]:
# Mini-batch size shared by all three loaders (previously repeated as 128).
BATCH_SIZE = 128

# follow_batch=["nxyz"] makes PyG also build an "nxyz_batch" assignment vector
# for the concatenated nxyz tensor. Only the training loader is shuffled.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, follow_batch=["nxyz"]
)
valid_loader = DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=False, follow_batch=["nxyz"]
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, follow_batch=["nxyz"]
)

In [None]:
# Sanity check: print the graph-assignment vector of a single training batch
# (nothing is printed if the loader yields no batches).
sample_batch = next(iter(train_loader), None)
if sample_batch is not None:
    print(sample_batch["batch"])

In [None]:
# Objective: mean absolute error on the "energy" target only. loss_coef has no
# "energy_grad" entry, consistent with the model's energy_and_force=False.
loss_fn = loss.build_mae_loss(loss_coef={"energy": 1})

# Optimize only the parameters that require gradients.
optimizer = Adam(
    (param for param in model.parameters() if param.requires_grad),
    lr=1e-4,
)

# Metrics reported by the logging hooks: energy MAE only.
train_metrics = [metrics.MeanAbsoluteError("energy")]

# Hooks are built in this order: epoch cap, CSV log, console log, LR schedule.
train_hooks = [
    # NOTE(review): presumably stops training after 300 epochs, fewer than the
    # n_epochs=400 passed to T.train — confirm which limit is intended.
    hooks.MaxEpochHook(300),
    hooks.CSVHook(OUTDIR, metrics=train_metrics),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator=" | ",
        time_strf="%M:%S",
    ),
    # Judging by the kwarg names: halve the LR after 30 epochs without
    # improvement, floor it at 1e-7, and stop once that floor is reached.
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=30,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True,
    ),
]

# Trainer that writes to OUTDIR and validates on valid_loader;
# checkpoint_interval=1 presumably means a checkpoint every epoch.
T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=valid_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
)

In [None]:
# Run training on the selected device.
# NOTE(review): the trainer also carries MaxEpochHook(300), which presumably
# ends training at 300 epochs, so n_epochs=400 may never be reached — confirm
# which limit is intended.
T.train(n_epochs=400, device=device)