Data Preprocess

In [1]:
from data import prepare_data
train_dataset_path,val_dataset_path = prepare_data(
    file_path = './example/ads/raw_data/adsorption_energy.csv',
    save_path = './example/ads/data',
    _col_name='smiles',
    test_size=0.2
    )

2024-12-17 09:06:24.192563: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-17 09:06:24.230452: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-17 09:06:24.230496: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-17 09:06:24.231697: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 09:06:24.238882: I tensorflow/core/platform/cpu_feature_guar

fused_multi_tensor is not installed corrected
fused_layer_norm is not installed corrected
fused_softmax is not installed corrected
Preparing data...
Converting SMILES to XYZ...
Data already exists, skipping...
Converting data to LMDB...
Done


Train

In [5]:
# dataset & hyparams
from utils import LMDBDataset,batch_collate_fn
from torch.utils.data import DataLoader, Dataset
num_epochs = 100
batch_size=128
learning_rate = 1e-4
patience = 20

train_dataset = LMDBDataset(train_dataset_path)
val_dataset = LMDBDataset(val_dataset_path)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,collate_fn=batch_collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,collate_fn=batch_collate_fn)


In [18]:
# model
import torch
import torch.nn as nn
from model import *
from data import *

mol2input = Mol2Input()
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))
fitting_model = FittingTestNet(hidden_dim=40,output_dim=train_dataset[0][2].shape[0])
unimol_model = UniMolModel(output_dim=1, data_type='molecule', remove_hs=False)
fitting_model.to(device)
fitting_model.train()
unimol_model.to(device)
unimol_model.train()
# optimizer = torch.optim.Adam(list(fitting_model.parameters()) + list(unimol_model.parameters()), lr=learning_rate)
optimizer = torch.optim.Adam(list(fitting_model.parameters()), lr=learning_rate)
criterion = nn.MSELoss()

In [None]:
# train
train_loss = []
val_loss = []
MIN_LOSS  = 1E4
_patience = 0
save_path = '/vepfs/fs_users/ycjin/Delta-ML-Framework/tasks/demo/example/ads/data',
for epoch in range(num_epochs):
    for coord,atype,target in train_loader:
        input_dict = mol2input.coord2unimol_inputs(coord,atype)
        for k in input_dict.keys(): input_dict[k] = input_dict[k].to(device)
        cls_reprs = unimol_model(return_repr=True,**input_dict)['cls_repr']
        pred = fitting_model(cls_reprs)
        loss = criterion(pred,target.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())

    with torch.no_grad():
        for coord,atype,target in val_loader:
            input_dict = mol2input.coord2unimol_inputs(coord,atype)
            for k in input_dict.keys(): input_dict[k] = input_dict[k].to(device)
            cls_reprs = unimol_model(return_repr=True,**input_dict)['cls_repr']
            pred = fitting_model(cls_reprs)
            loss = criterion(pred,target.to(device))
            val_loss.append(loss.item())

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {np.mean(train_loss):.4f}, Val Loss: {np.mean(val_loss):.4f}')
    if np.mean(val_loss) < MIN_LOSS:
        MIN_LOSS = np.mean(val_loss)
        pth_save_path = './example/ads/pth/'
        if not os.path.exists(pth_save_path):
            os.makedirs(pth_save_path)
        torch.save(unimol_model.state_dict(), pth_save_path +'model.pth')
        torch.save(fitting_model.state_dict(), pth_save_path + 'fit.pth')
        _patience = 0
    else:
        _patience+=1
        if _patience > patience:
            print('Early stopping')
            break


Epoch 1/100, Train Loss: 0.7710, Val Loss: 0.4132
Epoch 2/100, Train Loss: 0.5691, Val Loss: 0.3820
Epoch 3/100, Train Loss: 0.4914, Val Loss: 0.3664
Epoch 4/100, Train Loss: 0.4495, Val Loss: 0.3537
Epoch 5/100, Train Loss: 0.4209, Val Loss: 0.3446
Epoch 6/100, Train Loss: 0.4005, Val Loss: 0.3413
Epoch 7/100, Train Loss: 0.3853, Val Loss: 0.3351
Epoch 8/100, Train Loss: 0.3733, Val Loss: 0.3305
Epoch 9/100, Train Loss: 0.3633, Val Loss: 0.3272
Epoch 10/100, Train Loss: 0.3553, Val Loss: 0.3235
Epoch 11/100, Train Loss: 0.3486, Val Loss: 0.3209
Epoch 12/100, Train Loss: 0.3428, Val Loss: 0.3192
Epoch 13/100, Train Loss: 0.3379, Val Loss: 0.3163
Epoch 14/100, Train Loss: 0.3335, Val Loss: 0.3141
Epoch 15/100, Train Loss: 0.3296, Val Loss: 0.3116
Epoch 16/100, Train Loss: 0.3259, Val Loss: 0.3097
Epoch 17/100, Train Loss: 0.3226, Val Loss: 0.3079
Epoch 18/100, Train Loss: 0.3194, Val Loss: 0.3060
Epoch 19/100, Train Loss: 0.3166, Val Loss: 0.3046
Epoch 20/100, Train Loss: 0.3139, Val Lo

inference

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import *
from data import *
from utils import *

pth_save_path = './example/ads/pth/'

mol2input = Mol2Input()
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))
fitting_model = FittingNet(output_dim=1)
unimol_model = UniMolModel(output_dim=1, data_type='molecule', remove_hs=False)
fitting_model.load_state_dict(torch.load(pth_save_path +'fit.pth'))
unimol_model.load_state_dict(torch.load(pth_save_path + 'model.pth'))
fitting_model.to(device)
fitting_model.eval()
unimol_model.to(device)
unimol_model.eval()
print('Load model successfully!')

In [None]:
from ase.io import read
input_file = './example/ads/raw_data/0.xyz' # input molecule file
atom = read(input_file)
coord = [torch.tensor(atom.get_positions())]
atype = [np.array(atom.get_chemical_symbols())]

input_dict = mol2input.coord2unimol_inputs(coord,atype)
for k in input_dict.keys(): input_dict[k] = input_dict[k].to(device)
cls_reprs = unimol_model(return_repr=True,**input_dict)['cls_repr']
pred = fitting_model(cls_reprs)
pred