In [1]:
import os
import sys

os.environ["CUDA_VISIBLE_DEVICES"] = "7"
sys.path.append('../') 

In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

import random
from demo.ConfSeq import get_ConfSeq_pair_from_mol,get_mol_from_ConfSeq_pair,randomize_mol
import copy
from tqdm import tqdm
from rdkit.Chem import rdmolops, rdchem
from tqdm.contrib.concurrent import process_map  

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from transformers import BartForConditionalGeneration, BartConfig
from collections import OrderedDict
import torch
from torch import nn

import numpy as np
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt

from rdkit.ML.Scoring import Scoring
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
vocab = [chr(i) for i in range(33, 127)] 

for i in range(-180,180):
    vocab.append('<'+str(i)+'>')

vocab.append('<mask>')
vocab.append('<unk>')
vocab.append('<sos>')
vocab.append('<eos>')
vocab.append('<pad>')


config = BartConfig()

config.pad_token_id = vocab.index('<pad>')
config.eos_token_id = vocab.index('<eos>')
config.sos_token_id = vocab.index('<sos>')
config.forced_eos_token_id = None
config.encoder_layers = 6
config.encoder_attention_heads = 8
config.decoder_layers = 0
config.decoder_attention_heads = 0
config.d_model = 256
# config.share_embeddings = True
config.vocab_size = len(vocab)


vocab_dict = {char: idx for idx, char in enumerate(vocab)}
bart = BartForConditionalGeneration(config = config )


class CustomBartEncoder(nn.Module):
    def __init__(self, bart):
        super().__init__()
    
        # 加载 BART 模型
        self.bart_model = bart 
        
    def forward(self, input_ids, attention_mask=None):
        # 获取编码器输出
        outputs = self.bart_model.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

# 示例使用
bart_encoder_model  = CustomBartEncoder(bart)

In [4]:
def rm_invalid_chirality(mol):
    mol = copy.deepcopy(mol)
    """
    找出分子中同时出现在三个环中的原子。
    
    参数:
        mol: RDKit 分子对象
    返回:
        List[int]: 同时出现在三个环中的原子的索引列表
    """
    # 获取分子的所有环（SSSR：最小集的简单环）
    rings = rdmolops.GetSymmSSSR(mol)

    # 创建一个字典，记录每个原子出现在多少个环中
    atom_in_rings_count = {}

    # 遍历所有环，统计每个原子出现的次数
    for ring in rings:
        for atom_idx in ring:
            if atom_idx not in atom_in_rings_count:
                atom_in_rings_count[atom_idx] = 0
            atom_in_rings_count[atom_idx] += 1

    # 找出那些同时出现在三个环中的原子
    atoms_in_3_rings = [atom for atom, count in atom_in_rings_count.items() if count == 3]

    for atom_idx in atoms_in_3_rings:
        atom = mol.GetAtomWithIdx(atom_idx)
        atom.SetChiralTag(rdchem.ChiralType.CHI_UNSPECIFIED)

    return mol

In [5]:
def get_ConfSeq(query_mol):

    if query_mol != None:
        try:
            query_mol = rm_invalid_chirality(query_mol)
            query_mol = randomize_mol(query_mol)
            Chem.MolToSmiles(query_mol)    
            query_mol = Chem.RenumberAtoms(query_mol, eval(query_mol.GetProp('_smilesAtomOutputOrder'))) 
            Chem.MolToSmiles(query_mol,canonical = False)
            query_mol = Chem.RenumberAtoms(query_mol, eval(query_mol.GetProp('_smilesAtomOutputOrder'))) 
            in_smiles,TD_smiles = get_ConfSeq_pair_from_mol(query_mol)
            TD_smiles = TD_smiles.replace('<180>','<-180>')
            
        except:
            in_smiles,TD_smiles = '',''            
    else:
        in_smiles,TD_smiles = '',''
        
    return TD_smiles

In [6]:
def get_embedding_for_seq(i):
    int = [vocab_dict[i] for i in i.split(' ')] 
    int = torch.Tensor([int]).long()
    embed = bart_encoder_model(int).mean(dim=1).tolist()[0]
    return embed

In [7]:
class TensorDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.Tensor(self.data[idx]).long()

def collate_fn(batch):
    data = batch  # 解压每个样本中的数据和标签
    padded_data = pad_sequence(data, batch_first=True, padding_value = vocab.index('<pad>'))  # 对数据进行填充
    return padded_data


def mean_pooling(last_hidden_state, attention_mask):
    # 对每个样本进行池化，忽略 pad 的位置
    # 将 attention_mask 转换为 float 类型，并进行扩展
    attention_mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size())
    # 计算有效的特征总和和有效特征的数量
    sum_embeddings = (last_hidden_state * attention_mask).sum(dim=1)
    sum_mask = attention_mask.sum(dim=1)
    # 计算平均池化，避免除以零
    pooled_output = sum_embeddings / (sum_mask + 1e-8)  # 加小常数以避免除零
    return pooled_output

In [8]:
device = torch.device('cuda')

In [9]:
checkpoint = torch.load('./checkpoints/model_epoch_1.pth', map_location='cpu')  # 使用适当的设备    

new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    if k.startswith('module.'):
        new_state_dict[k[7:]] = v
    else:
        new_state_dict[k] = v

bart_encoder_model.load_state_dict(new_state_dict)
bart_encoder_model.eval()


lis = os.listdir('./data/PCBA/AVE_unbiased/')
lis = [i for i in lis if i[-4:] != '.txt']

for name in lis[:]:

    os.makedirs('./data/PCBA/AVE_unbiased_result/{}/'.format(name), exist_ok=True)
    
    file_lis = os.listdir('./data/PCBA/AVE_unbiased/{}/'.format(name))
    file_lis = [i for i in file_lis if 'ligand_.sdf' in i]

    suppl = Chem.SDMolSupplier('./data/PCBA/AVE_unbiased/{}/actives.sdf'.format(name))
    # 将有效的分子（非 None）保存在列表中
    actives_mols = [mol for mol in suppl if mol is not None]
    actives_mols_atom_num = [i.GetNumAtoms() for i in actives_mols]
    
    suppl = Chem.SDMolSupplier('./data/PCBA/AVE_unbiased/{}/inactives.sdf'.format(name))
    # 将有效的分子（非 None）保存在列表中
    decoys_mols = [mol for mol in suppl if mol is not None]
    decoys_mols_atom_num = [i.GetNumAtoms() for i in decoys_mols]


    actives_seqs = process_map(get_ConfSeq, tqdm(actives_mols), max_workers = 32)
    print('actives_seqs length:',len(actives_seqs))
    
    decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
    print('decoys_seqs length:',len(decoys_seqs))

    with open('./data/PCBA/AVE_unbiased_result/{}/actives_seqs.txt'.format(name),'w+') as f:
        f.write('\n'.join(actives_seqs))

    with open('./data/PCBA/AVE_unbiased_result/{}/decoys_seqs.txt'.format(name),'w+') as f:
        f.write('\n'.join(decoys_seqs))


    with open('./data/PCBA/AVE_unbiased/{}/actives_seqs.txt'.format(name),'r') as f:
        content = f.read()
    active_seqs = content.split('\n')   
    active_seqs = [i if i != '' else '*' for i in actives_seqs]    
    
    with open('./data/PCBA/AVE_unbiased/{}/decoys_seqs.txt'.format(name),'r') as f:
        content = f.read()
    decoys_seqs = content.split('\n') 
    decoys_seqs = [i if i != '' else '*' for i in decoys_seqs]


    active_ints = [[vocab_dict[i] for i in seq.split(' ')] for seq in actives_seqs]
    active_dataset = TensorDataset(active_ints)
    active_dataloader = DataLoader(active_dataset, batch_size=128,collate_fn=collate_fn)

    bart_encoder_model.to(device)
    
    active_embeds = []
    
    with torch.no_grad():  # 取消梯度计算
    
        for input_ids in active_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = (input_ids != vocab.index('<pad>')).long() # 创建 attention mask
            # 编码
            output = bart_encoder_model(input_ids,attention_mask)
            pooled_output = mean_pooling(output, attention_mask)
            active_embeds.append(pooled_output)
    
    active_embeds = torch.cat(tuple(active_embeds), dim=0).to('cpu')

    inactive_ints = [[vocab_dict[i] for i in seq.split(' ')] for seq in decoys_seqs]
    inactive_dataset = TensorDataset(inactive_ints)
    inactive_dataloader = DataLoader(inactive_dataset, batch_size=128,collate_fn=collate_fn)

    
    inactive_embeds = []
    
    with torch.no_grad():  # 取消梯度计算
    
        for input_ids in inactive_dataloader:
            input_ids = input_ids.to(device)
            attention_mask = (input_ids != vocab.index('<pad>')).long() # 创建 attention mask
            # 编码
            output = bart_encoder_model(input_ids,attention_mask)
            pooled_output = mean_pooling(output, attention_mask)
            inactive_embeds.append(pooled_output)
    
    inactive_embeds = torch.cat(tuple(inactive_embeds), dim=0).to('cpu')

    
    fingerprints = np.vstack([active_embeds,inactive_embeds])
    labels = np.array([1] * len(active_embeds) + [0] * len(inactive_embeds))  # 活性为1，非活性为0   

    mols_atom_num = np.array(actives_mols_atom_num + decoys_mols_atom_num)
    
    for file in file_lis:    
        print(file)

        query_mol = Chem.MolFromMolFile('./data/PCBA/AVE_unbiased/{}/{}'.format(name,file))
        query_mol_atom_num = query_mol.GetNumAtoms()
        
        bart_encoder_model.to(torch.device('cpu'))
        query_embeds = get_embedding_for_seq(get_ConfSeq(query_mol))
        query_embeds = torch.Tensor(query_embeds)
        euclidean_distance = torch.sqrt(torch.sum((query_embeds - fingerprints ) ** 2,dim = -1))
        similarity_score = 1 / (1 + euclidean_distance)

        factor = 2 * mols_atom_num/ (query_mol_atom_num + mols_atom_num)
        similarity_score = similarity_score * factor
        
        fpr, tpr, thresholds = roc_curve(labels.tolist(), similarity_score.tolist())
        
        # 计算 AUC
        roc_auc = auc(fpr, tpr)

        similarity_score_np = similarity_score.numpy()
        sort_idx = np.argsort(-similarity_score_np)
        sorted_scores = similarity_score_np[sort_idx]
        sorted_labels = labels[sort_idx].astype(bool)
        score_data = np.column_stack((sorted_scores, sorted_labels))

        bedroc_rdkit = Scoring.CalcBEDROC(score_data, col=1, alpha=80.5)
        ef_1 = Scoring.CalcEnrichment(score_data, col=1, fractions=[0.01])[0]
        ef_5 = Scoring.CalcEnrichment(score_data, col=1, fractions=[0.05])[0]

        # 构造结果行
        row = {
            "Protein": name,
            "Ligand_File": file,
            "ROC_AUC": round(roc_auc, 4),
            "BEDROC": round(bedroc_rdkit, 4),
            "EF5%": round(ef_5, 4),
            "EF1%": round(ef_1, 4)
        }

        # 写入 CSV 文件：如果文件不存在，则写入表头，否则追加行
        csv_path = f'./PCBA_evaluation_result.csv'
        if not os.path.exists(csv_path):
            df = pd.DataFrame([row])
            df.to_csv(csv_path, index=False)
        else:
            df = pd.DataFrame([row])
            df.to_csv(csv_path, mode='a', index=False, header=False)

100%|█████████████████████████████████████████████| 39/39 [00:02<00:00, 13.67it/s]
100%|████████████████████████████████████████████| 39/39 [00:00<00:00, 801.68it/s]


actives_seqs length: 39


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 358579/358579 [00:14<00:00, 24896.46it/s]
100%|███████████████████████████████████| 358579/358579 [02:17<00:00, 2603.21it/s]


decoys_seqs length: 358579
5lge_ligand_.sdf
4xs3_ligand_.sdf
5de1_ligand_.sdf
5l57_ligand_.sdf
4i3l_ligand_.sdf
4umx_ligand_.sdf
4xrx_ligand_.sdf
6adg_ligand_.sdf
6b0z_ligand_.sdf
5sun_ligand_.sdf
5l58_ligand_.sdf
4i3k_ligand_.sdf
5svf_ligand_.sdf
5tqh_ligand_.sdf


100%|█████████████████████████████████████████████| 24/24 [00:03<00:00,  6.45it/s]
100%|████████████████████████████████████████████| 24/24 [00:00<00:00, 457.54it/s]


actives_seqs length: 24


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|████████████████████████████████████████| 4068/4068 [00:04<00:00, 880.37it/s]
100%|███████████████████████████████████████| 4068/4068 [00:01<00:00, 2796.49it/s]


decoys_seqs length: 4068
4prg_ligand_.sdf
2i4j_ligand_.sdf
1zgy_ligand_.sdf
4fgy_ligand_.sdf
2q5s_ligand_.sdf
2p4y_ligand_.sdf
3hod_ligand_.sdf
5y2t_ligand_.sdf
3r8a_ligand_.sdf
4ci5_ligand_.sdf
2yfe_ligand_.sdf
5two_ligand_.sdf
5tto_ligand_.sdf
3b1m_ligand_.sdf
5z5s_ligand_.sdf


100%|█████████████████████████████████████████████| 13/13 [00:03<00:00,  3.75it/s]
100%|████████████████████████████████████████████| 13/13 [00:00<00:00, 243.84it/s]


actives_seqs length: 13


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|███████████████████████████████████████| 4376/4376 [00:03<00:00, 1202.16it/s]
100%|███████████████████████████████████████| 4376/4376 [00:01<00:00, 2921.67it/s]


decoys_seqs length: 4376
5du5_ligand_.sdf
2b1z_ligand_.sdf
4ivw_ligand_.sdf
2b1v_ligand_.sdf
5e1c_ligand_.sdf
2q70_ligand_.sdf
5due_ligand_.sdf
5dzi_ligand_.sdf
4pps_ligand_.sdf
2qr9_ligand_.sdf
5drj_ligand_.sdf
2qse_ligand_.sdf
2p15_ligand_.sdf
2qzo_ligand_.sdf
1l2i_ligand_.sdf


100%|█████████████████████████████████████████████| 24/24 [00:03<00:00,  6.95it/s]
100%|████████████████████████████████████████████| 24/24 [00:00<00:00, 466.27it/s]


actives_seqs length: 24


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 269345/269345 [00:12<00:00, 22443.12it/s]
100%|███████████████████████████████████| 269345/269345 [01:33<00:00, 2869.42it/s]


decoys_seqs length: 269345
6b73_ligand_.sdf


100%|█████████████████████████████████████████████| 17/17 [00:04<00:00,  3.65it/s]
100%|████████████████████████████████████████████| 17/17 [00:00<00:00, 399.56it/s]


actives_seqs length: 17


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 311600/311600 [00:15<00:00, 20187.91it/s]
100%|███████████████████████████████████| 311600/311600 [01:56<00:00, 2676.43it/s]


decoys_seqs length: 311600
4ldl_ligand_.sdf
3sn6_ligand_.sdf
4qkx_ligand_.sdf
3pds_ligand_.sdf
4lde_ligand_.sdf
6mxt_ligand_.sdf
3p0g_ligand_.sdf
4ldo_ligand_.sdf


100%|██████████████████████████████████████████| 546/546 [00:04<00:00, 115.04it/s]
100%|█████████████████████████████████████████| 546/546 [00:00<00:00, 1869.34it/s]


actives_seqs length: 546


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 244552/244552 [00:14<00:00, 17008.79it/s]
100%|███████████████████████████████████| 244552/244552 [01:31<00:00, 2684.90it/s]


decoys_seqs length: 244552
5x1v_ligand_.sdf
3h6o_ligand_.sdf
3u2z_ligand_.sdf
3me3_ligand_.sdf
3gqy_ligand_.sdf
4g1n_ligand_.sdf
4jpg_ligand_.sdf
5x1w_ligand_.sdf
3gr4_ligand_.sdf


100%|██████████████████████████████████████████| 653/653 [00:04<00:00, 135.06it/s]
100%|█████████████████████████████████████████| 653/653 [00:00<00:00, 2210.76it/s]


actives_seqs length: 653


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 262483/262483 [00:15<00:00, 17336.18it/s]
100%|███████████████████████████████████| 262483/262483 [01:40<00:00, 2609.19it/s]


decoys_seqs length: 262483
3a2j_ligand_.sdf
3a2i_ligand_.sdf


  actives_seqs = process_map(get_ConfSeq, tqdm(actives_mols), max_workers = 32)
100%|████████████████████████████████████████| 5362/5362 [00:05<00:00, 960.42it/s]
100%|███████████████████████████████████████| 5362/5362 [00:02<00:00, 2264.59it/s]


actives_seqs length: 5362


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|███████████████████████████████████| 101771/101771 [00:10<00:00, 9764.91it/s]
100%|███████████████████████████████████| 101771/101771 [00:37<00:00, 2678.53it/s]


decoys_seqs length: 101771
5l2m_ligand_.sdf
5l2o_ligand_.sdf
5l2n_ligand_.sdf
4x4l_ligand_.sdf
4wpn_ligand_.sdf
5ac2_ligand_.sdf
5tei_ligand_.sdf
4wp7_ligand_.sdf


100%|█████████████████████████████████████████████| 97/97 [00:05<00:00, 16.25it/s]
100%|████████████████████████████████████████████| 97/97 [00:00<00:00, 751.64it/s]


actives_seqs length: 97


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|█████████████████████████████████████| 32952/32952 [00:07<00:00, 4317.21it/s]
100%|█████████████████████████████████████| 32952/32952 [00:13<00:00, 2467.01it/s]


decoys_seqs length: 32952
1fap_ligand_.sdf
1nsg_ligand_.sdf
3fap_ligand_.sdf
5gpg_ligand_.sdf
4fap_ligand_.sdf
4jt5_ligand_.sdf
4drh_ligand_.sdf
4drj_ligand_.sdf
4jsx_ligand_.sdf
4dri_ligand_.sdf
2fap_ligand_.sdf


100%|███████████████████████████████████████████| 308/308 [00:05<00:00, 57.73it/s]
100%|█████████████████████████████████████████| 308/308 [00:00<00:00, 1635.53it/s]


actives_seqs length: 308


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|█████████████████████████████████████| 61461/61461 [00:08<00:00, 7613.43it/s]
100%|█████████████████████████████████████| 61461/61461 [00:20<00:00, 2935.47it/s]


decoys_seqs length: 61461
6g9h_ligand_.sdf
2ojg_ligand_.sdf
4qta_ligand_.sdf
5v62_ligand_.sdf
4qte_ligand_.sdf
4qp4_ligand_.sdf
1pme_ligand_.sdf
4xj0_ligand_.sdf
5buj_ligand_.sdf
3w55_ligand_.sdf
3sa0_ligand_.sdf
4qp9_ligand_.sdf
4zzn_ligand_.sdf
5ax3_ligand_.sdf
4qp3_ligand_.sdf


100%|███████████████████████████████████████████| 163/163 [00:04<00:00, 33.86it/s]
100%|█████████████████████████████████████████| 163/163 [00:00<00:00, 1292.38it/s]


actives_seqs length: 163


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 291039/291039 [00:14<00:00, 20669.76it/s]
100%|███████████████████████████████████| 291039/291039 [01:38<00:00, 2966.22it/s]


decoys_seqs length: 291039
2xwd_ligand_.sdf
3ril_ligand_.sdf
3rik_ligand_.sdf
2xwe_ligand_.sdf
2v3d_ligand_.sdf
2v3e_ligand_.sdf


100%|███████████████████████████████████████████| 360/360 [00:05<00:00, 71.57it/s]
100%|█████████████████████████████████████████| 360/360 [00:00<00:00, 1698.73it/s]


actives_seqs length: 360


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 350540/350540 [00:18<00:00, 18756.38it/s]
100%|███████████████████████████████████| 350540/350540 [02:05<00:00, 2796.34it/s]


decoys_seqs length: 350540
5fv7_ligand_.sdf


100%|█████████████████████████████████████████████| 64/64 [00:06<00:00, 10.02it/s]
100%|███████████████████████████████████████████| 64/64 [00:00<00:00, 1057.72it/s]


actives_seqs length: 64


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|████████████████████████████████████████| 3344/3344 [00:06<00:00, 502.81it/s]
100%|███████████████████████████████████████| 3344/3344 [00:01<00:00, 3016.72it/s]


decoys_seqs length: 3344
2vuk_ligand_.sdf
4agq_ligand_.sdf
4ago_ligand_.sdf
5g4o_ligand_.sdf
3zme_ligand_.sdf
5o1i_ligand_.sdf


100%|█████████████████████████████████████████████| 88/88 [00:06<00:00, 14.43it/s]
100%|███████████████████████████████████████████| 88/88 [00:00<00:00, 1106.01it/s]


actives_seqs length: 88


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|████████████████████████████████████████| 3818/3818 [00:06<00:00, 591.39it/s]
100%|███████████████████████████████████████| 3818/3818 [00:01<00:00, 3115.95it/s]


decoys_seqs length: 3818
3dt3_ligand_.sdf
6b0f_ligand_.sdf
6chw_ligand_.sdf
2iog_ligand_.sdf
2r6w_ligand_.sdf
5fqv_ligand_.sdf
2pog_ligand_.sdf
1xp1_ligand_.sdf
2ayr_ligand_.sdf
1xqc_ligand_.sdf
2iok_ligand_.sdf
5t92_ligand_.sdf
5ufx_ligand_.sdf
5aau_ligand_.sdf
2ouz_ligand_.sdf


100%|███████████████████████████████████████████| 194/194 [00:05<00:00, 34.79it/s]
100%|█████████████████████████████████████████| 194/194 [00:00<00:00, 1108.75it/s]


actives_seqs length: 194


  decoys_seqs = process_map(get_ConfSeq, tqdm(decoys_mols), max_workers = 32)
100%|██████████████████████████████████| 342518/342518 [00:18<00:00, 18526.40it/s]
100%|███████████████████████████████████| 342518/342518 [01:57<00:00, 2926.77it/s]


decoys_seqs length: 342518
5h84_ligand_.sdf
5h86_ligand_.sdf
5mlj_ligand_.sdf
