In [1]:
# Enable the autoreload extension in mode 2 so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the cell output.
%matplotlib inline
import warnings 
# NOTE(review): this suppresses every warning for the whole session, including
# deprecation and numerical warnings from torch / Lightning. Consider narrowing
# the filter to specific categories or modules.
warnings.filterwarnings("ignore")

In [2]:
import sys
import os

sys.path.append("../")

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import MLFlowLogger
from torch_geometric.loader import DataLoader
from argparse import ArgumentParser

from ConfRankPlus.training.lightning import LightningWrapper
from ConfRankPlus.data.dataset import PairDataset, NewMolecularDataset
from ConfRankPlus.inference.radius_graph import SmartRadiusGraph
from ConfRankPlus.inference.loading import load_ConfRankPlus
from ConfRankPlus.model import ConfRankPlus


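## Input file format
data = {
    'confid': conformer ID (str), used to tell conformers apart,
    'ensbid': molecule ID; all conformers of one molecule must share the same ensbid and have different confid values,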
    'energy':torch.tensor(energy, dtype=torch.float32),
    'total_charge':torch.tensor(charge, dtype=torch.float32),
    'z':torch.tensor(symbol_list, dtype=torch.long),
    'pos':torch.tensor(positions, dtype=torch.float32),
}

# Load Model

In [3]:
# Prefer the first CUDA device, but fall back to the CPU so that this cell also
# works on a machine without a GPU (a hard-coded 'cuda:0' fails there as soon as
# anything is moved to the device). This mirrors the availability check used
# when the Trainer is configured.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
compute_forces = False

# Pretrained model; below it only serves as the source of the initial weights.
old_model, fidelity_mapping = load_ConfRankPlus(device=device,
                                                dtype=dtype,
                                                compute_forces=compute_forces)

# Fresh ConfRankPlus instance that is fine-tuned in the cells below.
# load_state_dict is strict by default, so these hyperparameters must produce
# exactly the parameter names and shapes of the pretrained model.
model = ConfRankPlus(
    hidden_channels=128,
    num_blocks=2,
    int_emb_size=64,
    out_emb_channels=96,
    pair_basis_dim=16,
    triplet_basis_dim=16,
    cutoff=6.5,
    cutoff_threebody=4.0,
    additive_repulsion_energy=True,
    dataset_encoding_dim=2,
    num_dataset_embeddings=5,
)

# Copy the pretrained weights into the new instance. The call is kept as the
# last expression so the key-matching report is shown as the cell output.
static_dict = old_model.state_dict()
model.load_state_dict(static_dict)

<All keys matched successfully>

In [19]:
def energy_loss_fn(x, y):
    """Return the L1 loss (mean absolute error) between the tensors x and y."""
    return torch.nn.functional.l1_loss(x, y)


# Training configuration for an energy-only, pairwise fine-tuning run:
# no forces key is set and the force trade-off weight is zero.
wrapper_settings = dict(
    model=model,
    energy_key='energy',
    forces_key=None,
    forces_tradeoff=0.0,
    atomic_numbers_key="z",
    decay_factor=0.5,
    decay_patience=3,
    energy_loss_fn=energy_loss_fn,
    weight_decay=1E-8,
    xy_lim=None,
    pairwise=True,
)
lightning_module = LightningWrapper(**wrapper_settings)

In [20]:
# Restore fine-tuned weights from an earlier run.
# map_location='cpu' lets a checkpoint that was written on a GPU be read on any
# host; load_state_dict then copies the values into the module's own parameters,
# wherever those live.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoint files from a trusted source.
checkpoint = torch.load(
    'Data/Kwon_2000/FFTsGuess/epoch=0-step=396.ckpt',
    map_location='cpu',
    weights_only=False,
)
# strict=False tolerates missing/unexpected keys instead of raising; the returned
# report is displayed as the cell output and should be checked after every load.
lightning_module.load_state_dict(checkpoint["state_dict"], strict=False)

<All keys matched successfully>

# Load Dataset

In [5]:
# Directory (relative to the notebook) holding this run's train/val/test .pt
# files. The value already includes the leading 'Data/' component.
project_name = 'Data/Kwon_2000/cfrk_new'

In [6]:
train_file = f"{project_name}/train.pt"
val_file = f"{project_name}/val.pt"
test_file = f"{project_name}/test.pt"

# A single transform instance, built with the model's cutoff as radius, is
# shared by all three splits.
radius_graph_transform = SmartRadiusGraph(radius=model.cutoff)


def _load_pair_dataset(path):
    """Read the object stored at `path` with torch.load and wrap it in a PairDataset."""
    stored = torch.load(path, weights_only=False)
    return PairDataset(stored, transform=radius_graph_transform)


trainset = _load_pair_dataset(train_file)
valset = _load_pair_dataset(val_file)
testset = _load_pair_dataset(test_file)

Using torch_cluster for computing neighborlists.
Calculating ensembles ...
Calculating ensembles ...
Calculating ensembles ...


In [7]:
# Number of items in the train / validation / test sets, in that order.
tuple(len(split) for split in (trainset, valset, testset))

(28404, 3449, 3211)

In [21]:
batch_size = 150
# This setting used to be declared (as 1) but was never handed to any loader, so
# batches were always produced in the main process. It is now forwarded to every
# DataLoader; 0 keeps that effective behaviour. Raise it to load batches in
# background worker processes.
num_workers = 0


def _make_loader(dataset, shuffle):
    """Build a DataLoader for `dataset` using the shared batch size and worker count."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


# Only the training data is shuffled; validation and test keep a fixed order.
train_loader = _make_loader(trainset, shuffle=True)
val_loader = _make_loader(valset, shuffle=False)
test_loader = _make_loader(testset, shuffle=False)


In [22]:
# Validation metric that drives both checkpoint selection and early stopping.
monitor_metric = "ptl/val_loss_pairwise"

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor=monitor_metric,
    mode="min",  # explicit, and consistent with the early-stopping callback below
    save_top_k=1,
    # project_name already starts with 'Data/'. Prefixing it again would write
    # the checkpoints to 'Data/Data/...', away from the dataset files.
    dirpath=project_name,
)
early_stop_callback = EarlyStopping(
    monitor=monitor_metric,
    min_delta=0.0,
    patience=20,
    verbose=True,
    mode="min",
)
lr_monitor = LearningRateMonitor(logging_interval="epoch")
callbacks = [checkpoint_callback, early_stop_callback, lr_monitor]

# Pin the first GPU when CUDA is available. Otherwise let Lightning choose both
# the accelerator and the device count: an explicit device index list such as
# [0] is not a valid `devices` value for the CPU accelerator.
use_gpu = torch.cuda.is_available()
trainer = pl.Trainer(
    max_epochs=100,
    enable_progress_bar=True,
    callbacks=callbacks,
    logger=True,
    log_every_n_steps=200,
    accelerator="gpu" if use_gpu else "auto",
    devices=[0] if use_gpu else "auto",
    precision=32,
    # Inference mode is only allowed while no forces are computed; for force
    # computation it must be set to False.
    inference_mode=True,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [23]:
# Run the fine-tuning loop on the training loader, validating on the validation loader.
trainer.fit(
    model=lightning_module,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type         | Params | Mode 
-----------------------------------------------
0 | model | ConfRankPlus | 471 K  | train
-----------------------------------------------
470 K     Trainable params
624       Non-trainable params
471 K     Total params
1.886     Total estimated model params size (MB)
155       Modules in train mode
0         Modules in eval mode


Epoch 0:  69%|██████▉   | 131/190 [00:31<00:14,  4.18it/s, v_num=2, ptl/train_loss_pairwise_step=2.800]  


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

In [34]:
# Release GPU memory that PyTorch's caching allocator holds but is not currently using.
torch.cuda.empty_cache()

# Pred Test

In [58]:
# Evaluation set for the prediction cells below.
# (To score the held-out split of the training project instead, point test_file
# at f'{project_name}/test.pt'.)
test_file = 'Data/Kwon_Chiral2/Kwon_Chiral2.pt'

# Read the stored object first, then wrap it; './temp' is handed to
# NewMolecularDataset as its first positional argument.
temp_records = torch.load(test_file, weights_only=False)
temp_testset = NewMolecularDataset(
    './temp',
    temp_records,
    transform=radius_graph_transform,
)

# Fixed order (no shuffling), 100 items per batch.
temp_test_loader = DataLoader(temp_testset, batch_size=100, shuffle=False)

In [59]:
# Predicted energy for every item of the evaluation set, keyed by "<ensbid>_<confid>".
final_energies = {}

# Put the network in evaluation mode; training may have left it in train mode.
predictor = lightning_module.model
predictor.eval()
# Run on whatever device the weights currently live on, so the cell works both
# before and after the Trainer has moved the model.
predictor_device = next(predictor.parameters()).device

with torch.jit.optimized_execution(False):
    for each_batch in temp_test_loader:
        each_batch = each_batch.to(predictor_device)
        # .cpu() is required before .numpy() whenever the output is on a GPU.
        predict_result = predictor(each_batch)['energy'].detach().cpu().numpy().tolist()
        for ensbid, confid, result in zip(each_batch.ensbid, each_batch.confid, predict_result):
            final_energies[f"{ensbid}_{confid}"] = result

In [41]:
# Alternative prediction path that feeds the model an explicit input dict and
# collects the energies as a flat list in loader order.
# The list has its own name: rebinding `final_energies` here would replace the
# id -> energy dict built by the previous cell when the notebook is run top to
# bottom, breaking the later `.keys()` call and changing what gets saved.
final_energy_list = []

model.eval()
model_device = next(model.parameters()).device

with torch.jit.optimized_execution(False):
    for each_batch in temp_test_loader:
        data = each_batch.to(model_device)

        model_input_dict = dict(
            pos=data['pos'],
            z=data['z'].long(),
            edge_index=data['edge_index'],
            total_charge=data['total_charge'],
            batch=data['batch'],
            # Dataset index 0 for every entry of z.
            dataset_idx=torch.zeros_like(data['z'], dtype=torch.long),
        )
        # .cpu() is required before .numpy() whenever the output is on a GPU.
        predictions = model(model_input_dict)['energy'].detach().cpu().numpy().tolist()
        final_energy_list += predictions

In [61]:
# Write the collected predictions to disk.
# (For the training project's own test split, save to f'{project_name}/Pred.pt' instead.)
pred_file = 'Data/Kwon_Chiral2/Kwon_Chiral2_pred.pt'
torch.save(final_energies, pred_file)

In [None]:
# Prediction keys from position 1000 onward, in insertion order.
list(final_energies)[1000:]

In [60]:
# Display the full "<ensbid>_<confid>" -> predicted energy mapping.
# NOTE(review): this prints every entry; for large evaluation sets show a slice instead.
final_energies

{'ts1_00000_008_009_0_0': -386062.40625,
 'ts2_00000_008_009_0_0': -386063.75,
 'ts1_00000_008_009_0_1': -386060.96875,
 'ts2_00000_008_009_0_1': -386058.375,
 'ts1_00000_008_009_1_0': -386061.96875,
 'ts2_00000_008_009_1_0': -386063.5625,
 'ts1_00000_008_009_1_1': -386062.5,
 'ts2_00000_008_009_1_1': -386061.15625,
 'ts2_00000_009_008_0_0': -386060.625,
 'ts2_00000_009_008_0_1': -386063.78125,
 'ts2_00000_009_008_1_0': -386059.4375,
 'ts2_00000_009_008_1_1': -386063.34375,
 'ts1_00001_008_009_0_0': -386063.0625,
 'ts2_00001_008_009_0_0': -386063.6875,
 'ts1_00001_008_009_0_1': -386063.0625,
 'ts2_00001_008_009_0_1': -386060.5,
 'ts1_00001_008_009_1_0': -386060.15625,
 'ts2_00001_008_009_1_0': -386063.75,
 'ts1_00001_008_009_1_1': -386061.9375,
 'ts2_00001_008_009_1_1': -386059.125,
 'ts2_00001_009_008_0_0': -386062.1875,
 'ts2_00001_009_008_0_1': -386062.84375,
 'ts2_00001_009_008_1_0': -386061.25,
 'ts2_00001_009_008_1_1': -386064.0625}

In [1]:
# XTB Test

In [1]:
from xtb.interface import Molecule, Param, XTB, Constraint

ModuleNotFoundError: No module named 'xtb'