inference

In [6]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from model import *
from data import *
from utils import *

# Directory the two state-dict checkpoints below are loaded from.
pth_save_path = './example/diff/pth/'

mol2input = Mol2Input()

# Prefer the first CUDA device; otherwise run everything on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

fitting_model = FittingNet(output_dim=1)
unimol_model = UniMolModel(output_dim=1, data_type='molecule', remove_hs=False)

# map_location remaps the stored tensors onto `device`, so a checkpoint whose
# tensors were saved from a GPU can still be loaded on a CPU-only machine.
fitting_model.load_state_dict(
    torch.load(pth_save_path + 'atomic_fit_nh.pth', map_location=device)
)
unimol_model.load_state_dict(
    torch.load(pth_save_path + 'atomic_model_nh.pth', map_location=device)
)

# Move both networks to the selected device and call eval() on them before
# they are used for prediction.
fitting_model.to(device)
fitting_model.eval()
unimol_model.to(device)
unimol_model.eval()
print('Load model successfully!')

2024-11-25 04:02:47 | model/unimol.py | 114 | INFO | Uni-Mol(QSAR) | Loading pretrained weights from /vepfs/fs_users/ycjin/ads_predict_tools/weights/mol_pre_all_h_220816.pt


Load model successfully!


单项预测

In [7]:
from ase import Atoms
from ase.io import read,write
import py3Dmol
from ase.io import read

input_file = './example/pub25/data/xyz/400.xyz' # input molecule file
atom = read(input_file)

# Remove hydrogen atoms: build the mask once and reuse it for the coordinates,
# the element symbols and the Atoms object itself.
sym = np.array(atom.get_chemical_symbols())
heavy_mask = sym != 'H'
coord = [torch.tensor(atom.get_positions())[heavy_mask]]
atype = [sym[heavy_mask]]
atom = atom[heavy_mask]

# Prediction
input_dict = mol2input.coord2unimol_inputs(coord, atype)
input_dict = {k: input_dict[k].to(device) for k in input_dict.keys()}

# Inference only: no autograd graph is needed, so gradient tracking is disabled.
with torch.no_grad():
    atomic_reprs = unimol_model(return_repr=True, **input_dict)['atomic_reprs']
    # Presumably one entry per input molecule (a single one here) -- TODO confirm.
    # After the loop, atomic_p / p hold the results of the last entry.
    for atomic_repr in atomic_reprs:
        atomic_p = fitting_model(atomic_repr)
        # Sum the per-atom outputs already computed instead of running the
        # fitting network a second time.
        p = torch.sum(atomic_p)

print(atomic_p)  # per-atom contributions
print(p)  # total (sum of the per-atom contributions)


# Visualization: map the per-atom contributions linearly onto integer gray
# levels in the range 50..250 (minimum -> 50, maximum -> 250).
value = atomic_p.detach().cpu().numpy().reshape(-1)
value_span = np.max(value) - np.min(value)
if value_span > 0:
    value = (value - np.min(value)) / value_span * 200 + 50
else:
    # All contributions are equal (e.g. a single atom): the min-max formula
    # would compute 0/0, so use the middle of the 50..250 range instead.
    value = np.full_like(value, 150)
value = value.astype(np.int16)


# Per-element display settings, keyed by element symbol. Each value is a
# one-element list whose first item is used as the sphere scale when drawing.
setting = {
    element: [0.3]
    for element in ('H', 'C', 'O', 'N', 'S', 'F', 'Cl', 'Br', 'I')
}

def self_hex(n):
    """Format ``n`` as a two-digit lowercase hex string for one colour channel.

    The result is meant to be concatenated into a ``#rrggbb`` colour, so it is
    always exactly two hex digits: ``n`` is converted to ``int`` and clamped to
    the 8-bit range 0..255 first. (Slicing ``hex(n)[2:]`` would yield a
    malformed string such as ``'x1'`` for negative input and three digits for
    values above 255.)

    Args:
        n: Integer-like channel value (Python or NumPy integer).

    Returns:
        A two-character string in ``'00'``..``'ff'``.
    """
    channel = min(max(int(n), 0), 255)
    return format(channel, '02x')

# Write the hydrogen-free structure to disk and read it back once as XYZ text;
# both viewers below load the same text.
write('md_n.xyz', atom)
with open('md_n.xyz') as xyz_file:
    xyz_block = xyz_file.read()

# Contribution map: every atom is drawn as a sphere whose gray level comes
# from its entry in `value`.
view = py3Dmol.view(width=300, height=300)
# The format is passed positionally, as in py3Dmol's documented usage.
view.addModel(xyz_block, 'xyz')
for i, symbol in enumerate(atype[0]):
    # Same two-digit hex value for R, G and B gives a gray colour.
    gray = self_hex(value[i]) * 3
    # Elements missing from `setting` fall back to the common 0.3 scale
    # instead of raising KeyError.
    scale = setting.get(symbol, [0.3])[0]
    view.setStyle({'index': i}, {'sphere': {'scale': scale, 'color': '#' + gray}})
view.zoomTo(animate=True)
view.show()

# Plain molecule view: the same structure with a uniform sphere style.
# All calls go to view_2, so this second viewer is the one that is styled and shown.
view_2 = py3Dmol.view(width=300, height=300)
view_2.addModel(xyz_block, 'xyz')
view_2.setStyle({'sphere': {'scale': 0.3}})
view_2.zoomTo(animate=True)
view_2.show()

tensor([[0.1258],
        [0.1808],
        [0.1754],
        [0.1441],
        [0.1706],
        [0.2468],
        [0.1925],
        [0.1744],
        [0.1812],
        [0.1226],
        [0.2878],
        [0.2340],
        [0.1205],
        [0.2017],
        [0.1549],
        [0.1397],
        [0.1510],
        [0.1292],
        [0.1402],
        [0.1180],
        [0.1076],
        [0.1383],
        [0.1149],
        [0.1587],
        [0.2012]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor(4.1119, device='cuda:0', grad_fn=<SumBackward0>)
