In [1]:
import os
import yaml
import torch
from easydict import EasyDict
from rdkit.Chem import AllChem

%load_ext autoreload
%autoreload 2

os.chdir('..')
os.getcwd()

from models.epsnet import *
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
from utils.chem import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device shared by the model and every batch built later in the notebook.
device = 'cuda:0'
model = DualEncoderEpsNetwork().to(device)
# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even if it was saved from a different GPU index (or from the CPU).
# The last expression is kept bare so the key-matching report is displayed.
model.load_state_dict(torch.load("./param/cross_align.pt", map_location=device))

<All keys matched successfully>

In [3]:
def bond_type_to_int(bond):
    """Map a bond to the integer code used as its edge type.

    Args:
        bond: Object exposing ``GetBondType()`` (e.g. an RDKit ``Bond``).

    Returns:
        int: 1 for single, 2 for double, 3 for triple, 12 for aromatic.

    Raises:
        ValueError: If the bond type is not one of the four supported types.
    """
    bond_type = bond.GetBondType()
    known = Chem.rdchem.BondType
    code_by_type = (
        (known.SINGLE, 1),
        (known.DOUBLE, 2),
        (known.TRIPLE, 3),
        (known.AROMATIC, 12),
    )
    for candidate, code in code_by_type:
        if bond_type == candidate:
            return code
    # The previous `assert "Bond type error"` asserted a non-empty string,
    # which is always truthy, so unsupported bonds silently returned None.
    raise ValueError(f"Unsupported bond type: {bond_type!r}")

def mol_to_graph_data_obj(mol):
    """Convert a molecule with a conformer into a graph ``Data`` object.

    Args:
        mol: Molecule exposing ``GetAtoms()``, ``GetBonds()`` and
            ``GetConformer()`` (e.g. an RDKit ``Mol`` with 3D coordinates).

    Returns:
        Data: Object built with the fields
            ``atom_type``  - int tensor of atomic numbers, one per atom;
            ``edge_index`` - long tensor of shape (2, 2 * num_bonds), each
                             bond stored in both directions;
            ``edge_type``  - 1-D float tensor of bond codes from
                             ``bond_type_to_int``, aligned with ``edge_index``;
            ``pos``        - float tensor of shape (num_atoms, 3) with the
                             conformer coordinates.
    """
    # Node information: the atomic number of every atom.
    atom_type = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.int
    )

    # Edge information: add each bond in both directions with its bond code.
    edges = []
    bond_types = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_type = bond_type_to_int(bond)
        edges += [(begin, end), (end, begin)]
        bond_types += [bond_type, bond_type]

    # reshape(-1, 2) keeps the (2, num_edges) layout when there are no bonds;
    # without it an empty edge list would yield a tensor of shape (0,).
    edge_index = (
        torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )
    edge_type = torch.tensor(bond_types, dtype=torch.float).view(-1)

    # Atom coordinates, read from the molecule's conformer in atom order.
    conf = mol.GetConformer()
    coordinates = []
    for atom in mol.GetAtoms():
        point = conf.GetAtomPosition(atom.GetIdx())
        coordinates.append([point.x, point.y, point.z])
    # reshape(-1, 3) likewise keeps the (num_atoms, 3) layout for zero atoms.
    pos = torch.tensor(coordinates, dtype=torch.float).reshape(-1, 3)

    # Assemble the PyTorch Geometric Data object.
    return Data(atom_type=atom_type, edge_index=edge_index, edge_type=edge_type, pos=pos)

In [4]:
# Read the first record of the ligand SDF without sanitization.
mol = Chem.SDMolSupplier("./mol/ligand_init.sdf", sanitize=False)[0]
# The supplier yields None for a record it cannot parse; fail here with a
# clear message instead of with an AttributeError in a later cell.
if mol is None:
    raise ValueError("Could not parse the first molecule in ./mol/ligand_init.sdf")

In [7]:
n_steps = 200
SEED = 0

# The initial noise below is random; seed it so a re-run reproduces the pose.
torch.manual_seed(SEED)

# batch1 becomes the ligand to be generated; batch2 is the template, with its
# coordinates shifted so that their mean sits at the origin.
batch1 = mol_to_graph_data_obj(mol)
batch2 = copy.deepcopy(batch1)
mean_pos = batch2['pos'].mean(0)
batch2['pos'] = batch2['pos'] - mean_pos
# Use the same `device` the model was moved to instead of a hard-coded .cuda().
ligand_batch = Batch.from_data_list([batch1]).to(device)
template_batch = Batch.from_data_list([batch2]).to(device)
model.eval()

# Replace the ligand coordinates with Gaussian noise scaled by 1/4.
ligand_batch.pos.normal_()
ligand_batch.pos = ligand_batch.pos/4

pos_gen, pos_gen_traj = model.langevin_dynamics_sample_ddim_g(
    ligand_batch=ligand_batch,
    template_batch=template_batch,
    # NOTE(review): no edge-order transform is applied in this notebook, so
    # presumably the sampler has to extend the bond graph itself - confirm.
    extend_order=True,
    n_steps=n_steps,
    step_lr=1e-6,
    w_global=0.9,
    global_start_sigma=1.0,
    clip=1000.0,
    clip_local=None,
    sampling_type='generalized',
    eta=1.0
)
# Shift both the template and the generated pose back to the original frame.
template_batch['pos'] = template_batch['pos'] + mean_pos.to(device)
pos_gen = pos_gen.cpu() + mean_pos.cpu()


  merge_batch.graph_idx = torch.tensor(merge_batch.batch)
  bgraph_adj = torch.sparse.LongTensor(
sample: 199it [00:03, 63.23it/s]


In [8]:
# Build two molecules from `mol`: one carrying the template coordinates and
# one carrying the generated coordinates.
init_mol, gen_mol = (
    set_rdmol_positions(mol, coords)
    for coords in (template_batch[0]['pos'], pos_gen)
)

In [9]:
import py3Dmol

In [10]:
def show(mol, width=500, height=500):
    """Render a molecule as sticks in an inline py3Dmol viewer.

    Args:
        mol: Molecule accepted by ``Chem.MolToMolBlock``.
        width: Viewer width passed to ``py3Dmol.view`` (default 500).
        height: Viewer height passed to ``py3Dmol.view`` (default 500).

    Returns:
        None. The viewer is displayed as a side effect.
    """
    mol_block = Chem.MolToMolBlock(mol)
    view = py3Dmol.view(width=width, height=height)
    view.addModel(mol_block, 'mol')
    view.setStyle({'stick': {}})
    view.zoomTo()
    view.show()

In [11]:
# Overlay the template pose (model 0, orange) and the generated pose
# (model 1, green) in a single viewer.
view = py3Dmol.view(width=500, height=500)
for overlay_mol in (init_mol, gen_mol):
    mblock = Chem.MolToMolBlock(overlay_mol)
    view.addModel(mblock, 'mol')
for model_idx, stick_color in enumerate(('orange', 'green')):
    view.setStyle({'model': model_idx}, {'stick': {'color': stick_color}})
view.zoomTo()
view.show()