In [1]:
import os

# Kernel working directory; every relative path used below
# ('../config.ini', '../data/...', '../save_model/...') resolves against it.
working_dir = os.getcwd()
working_dir

'/home/a1356913256/python/molecular_generation/CPL-Diff/scripts'

In [2]:
import concurrent.futures
import configparser
import functools
import math
import random
from collections import defaultdict, Counter
from statistics import mean, stdev

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from Bio import SeqIO, Align
from esm import FastaBatchedDataset
from modlamp.descriptors import GlobalDescriptor
from torch import device
from torch.cuda import is_available
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForMaskedLM

from CPLDiff.models.Denoiser import Denoiser
from CPLDiff.utils.CPLDiffDataset import XDataset
from CPLDiff.utils.utils import set_seed
from CPLDiff.utils.utils import extract

In [3]:
# Choose which peptide family to sample and how strongly to guide it.
#   sample_type: 'amp' (antimicrobial), 'afp' (antifungal) or 'avp' (antiviral); case-insensitive
#   cfs:         classifier-free guidance weight w, pred_x0 = (1 + w) * cond - w * uncond
#   n:           number of sampling rounds
#   sample_num:  sequences drawn per round (sample_num * n distinct sequences are kept)
sample_type = 'amp'
cfs = 1.5
n = 1
sample_num = 1000

# Normalise once so later cells can compare case-sensitively (otherwise e.g.
# 'AMP' would pass validation but fall through to the antiviral branch).
# An explicit raise is used because `assert` is stripped under `python -O`.
sample_type = sample_type.lower()
if sample_type not in ('amp', 'afp', 'avp'):
    raise ValueError("sample_type must be one of 'amp', 'afp', 'avp', got %r" % sample_type)

In [4]:
# Select the compute device. Call torch.device explicitly: writing
# `device = device(...)` rebinds the imported class to an instance, so re-running
# the cell raised "TypeError: 'torch.device' object is not callable".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [5]:
# Read the shared hyper-parameters. ConfigParser.read silently skips missing
# files (it returns the list of files it actually parsed), which would otherwise
# surface later as a confusing NoSectionError - fail loudly here instead.
CONFIG_PATH = '../config.ini'
conf = configparser.ConfigParser()
if not conf.read(CONFIG_PATH):
    raise FileNotFoundError("config file not found: %s (cwd=%s)" % (CONFIG_PATH, os.getcwd()))
conf_dict = dict(conf.items('CPLDiff_conf'))
# Project helper; presumably seeds python/numpy/torch RNGs - TODO confirm.
set_seed(int(conf_dict['seed']))

In [6]:
# Class ids fed to the denoiser as its condition `y`; ddpm_sample uses id 3
# (outside this table) for the unconditional branch of guidance.
labels = dict(antimicrobial=0, antifungal=1, antiviral=2)
# Peptide family name ('antimicrobial' / 'antifungal' / 'antiviral'); set from sample_type later.
peptide_file = str()

class LengthSampler:
    """
    Empirical length distribution of a FASTA dataset.

    Builds a categorical distribution over lengths 0..max_len from the dataset's
    sequences (longer sequences are counted as max_len) and samples generation
    lengths from it, so generated peptides follow the dataset's length profile.
    """

    def __init__(self, path, max_len=254):
        """
        :param path: FASTA file whose sequence lengths define the distribution.
        :param max_len: largest length that can be sampled; longer sequences are clipped to it.
        :raises ValueError: if the FASTA file contains no sequences.
        """
        with open(path, "r") as fasta_file:
            data = [str(record.seq) for record in SeqIO.parse(fasta_file, "fasta")]
        self.dataset_len = np.clip([len(t) for t in data], a_min=0, a_max=max_len)
        self.distrib = self._length_distribution(self.dataset_len, max_len)

    @staticmethod
    def _length_distribution(lengths, max_len):
        """Normalised histogram of non-negative integer ``lengths`` over bins 0..max_len."""
        counts = np.bincount(np.asarray(lengths, dtype=np.int64), minlength=max_len + 1)
        total = counts.sum()
        if total == 0:
            # An empty dataset would give 0/0 = NaN probabilities, which
            # np.random.multinomial rejects with an unhelpful message.
            raise ValueError("cannot build a length distribution from an empty dataset")
        return counts / total

    def sample(self, num_samples):
        """
        Draw ``num_samples`` lengths from the empirical distribution.

        :return: int array of shape (num_samples,).
        """
        # One-hot multinomial draws + argmax is categorical sampling; kept in this
        # form so seeded runs reproduce the same lengths as before.
        return np.argmax(np.random.multinomial(1, self.distrib, size=num_samples), axis=1)


def ddpm_sample(denoiser, x_t_shape, timesteps, betas, sqrt_alphas, alphas_cumprod, alphas_cumprod_prev, sqrt_alphas_cumprod_prev,
                len_list=None, use_attention=True, cfs=1., peptide_type='antimicrobial', return_attn_list=False, index=1):
    """
    Ancestral DDPM sampling with classifier-free guidance; the denoiser predicts x0.

    :param denoiser: network called as denoiser(x_t, t, y=label[, attention_mask=...][, return_attn_matrix=True]).
    :param x_t_shape: [batch, seq_len, embed_dim] of the latent; seq_len includes the two special
        tokens. The batch size is replaced by the number of lengths actually used. The caller's
        sequence is not modified.
    :param timesteps: number of diffusion steps T; sampling runs t = T, ..., 1.
    :param betas, sqrt_alphas, alphas_cumprod, alphas_cumprod_prev, sqrt_alphas_cumprod_prev:
        1-D schedule tensors indexed by t - 1.
    :param len_list: optional explicit peptide lengths; otherwise lengths are drawn from the
        training FASTA of ``peptide_type``.
    :param use_attention: pass a padding attention mask to the denoiser.
    :param cfs: guidance weight w; pred_x0 = (1 + w) * cond - w * uncond.
    :param peptide_type: key of the global ``labels`` and name of the training FASTA file.
    :param return_attn_list: also return the attention matrices of the last conditional call.
    :param index: only used in the progress-bar label.
    :return: sampled x0, or (x0, attn_list) when ``return_attn_list`` is True.
    """
    seq_len = x_t_shape[1]
    if len_list is None:
        # Draw lengths from the training set of the requested family (previously the
        # global `peptide_file` was used here, ignoring `peptide_type`).
        length_sampler = LengthSampler(path='../data/train/%s.fasta' % peptide_type, max_len=seq_len - 2)
        ones_counts = length_sampler.sample(x_t_shape[0]) + 2
    else:
        ones_counts = np.asarray(len_list) + 2
    # +2 for the special tokens at the head and tail of each sequence.
    batch = len(ones_counts)
    x_t_shape = [batch] + list(x_t_shape[1:])

    # Condition tensors are sized from the final batch so they match x_t even when
    # len_list has a different length than the requested x_t_shape[0].
    # Label 3 is the null class used for the unconditional branch of guidance.
    unconditional = torch.full((batch,), 3, device=device, dtype=torch.int)
    conditional = torch.full((batch,), labels[peptide_type], device=device, dtype=torch.int)

    # 1 for the first `count` positions of each row, 0 for padding.
    lengths = torch.as_tensor(ones_counts, dtype=torch.long)
    attention_mask = (torch.arange(seq_len).unsqueeze(0) < lengths.unsqueeze(1)).long().to(device)
    mask_kwargs = {'attention_mask': attention_mask} if use_attention else {}

    attn_list = None
    x_t = torch.randn(x_t_shape).to(device)
    for t in tqdm(reversed(range(1, timesteps + 1)), desc='index %s sampling loop time step' % index, total=timesteps):
        batched_times = torch.full((batch,), t, device=device, dtype=torch.long)
        if return_attn_list:
            conditional_pred_x0, attn_list = denoiser(x_t, batched_times, y=conditional, return_attn_matrix=True, **mask_kwargs)
        else:
            conditional_pred_x0 = denoiser(x_t, batched_times, y=conditional, **mask_kwargs)
        unconditional_pred_x0 = denoiser(x_t, batched_times, y=unconditional, **mask_kwargs)
        pred_x0 = (1 + cfs) * conditional_pred_x0 - cfs * unconditional_pred_x0
        # The final step returns the x0 prediction directly instead of sampling.
        if t == 1:
            x_t = pred_x0
            break

        # Schedule tensors are 0-based, so step t reads entry t - 1.
        batched_times_pre = batched_times - 1
        alpha_cumprod_t = extract(alphas_cumprod, batched_times_pre, x_t_shape)
        alpha_cumprod_prev_t = extract(alphas_cumprod_prev, batched_times_pre, x_t_shape)
        sqrt_alpha_cumprod_prev_t = extract(sqrt_alphas_cumprod_prev, batched_times_pre, x_t_shape)
        beta_t = extract(betas, batched_times_pre, x_t_shape)
        sqrt_alpha_t = extract(sqrt_alphas, batched_times_pre, x_t_shape)
        one_minus_alpha_cumprod_t = 1 - alpha_cumprod_t

        # Mean of the posterior q(x_{t-1} | x_t, x0) (Ho et al. 2020, Eq. 7).
        miu = (sqrt_alpha_cumprod_prev_t * beta_t / one_minus_alpha_cumprod_t) * pred_x0 + \
              (sqrt_alpha_t * (1 - alpha_cumprod_prev_t) / one_minus_alpha_cumprod_t) * x_t
        # Posterior variance beta_tilde_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t.
        sigma = (1 - alpha_cumprod_prev_t) / one_minus_alpha_cumprod_t * beta_t
        # Reparameterised sample x_{t-1} = miu + sqrt(sigma) * eps.
        x_t = miu + sigma.sqrt() * torch.randn_like(x_t)
    return (x_t, attn_list) if return_attn_list else x_t

In [7]:
# Load the trained denoiser and the ESM-2 masked-LM whose lm_head maps embeddings
# to token logits. Everything goes to `device` (instead of a hard-coded .cuda())
# so the notebook also runs on a CPU-only machine.
esm_model_path = '../' + conf_dict['denoiser_esm_model_name']
denoiser_mlp = [int(i) for i in conf_dict['denoiser_mlp'].split(',')]
denoiser = Denoiser(esm_model_path, int(conf_dict['denoiser_embedding']), denoiser_mlp).to(device)
# NOTE: torch.load unpickles arbitrary objects - only load checkpoints you trust.
denoiser.load_state_dict(torch.load("../save_model/denoise_model.pkl", map_location=device))
tokenizer = AutoTokenizer.from_pretrained(esm_model_path, trust_remote_code=True)
esm2_model = AutoModelForMaskedLM.from_pretrained(esm_model_path, trust_remote_code=True).to(device)
decoder = esm2_model.lm_head
denoiser.eval()
esm2_model.eval()
decoder.eval()

Some weights of EsmModel were not initialized from the model checkpoint at ../ESM2-8M and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmLMHead(
  (dense): Linear(in_features=320, out_features=320, bias=True)
  (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (decoder): Linear(in_features=320, out_features=33, bias=False)
)

In [8]:
# Resolve sample_type to its peptide family, then load the matching training set
# and fine-tuned decoder head. The lookup is case-insensitive so e.g. 'AMP' maps to
# antimicrobial instead of falling through to the antiviral branch; unknown types
# raise KeyError. (A module-level `global` statement is a no-op and was dropped.)
PEPTIDE_FAMILIES = {'amp': 'antimicrobial', 'afp': 'antifungal', 'avp': 'antiviral'}
peptide_file = PEPTIDE_FAMILIES[sample_type.lower()]
train_dataset = XDataset(pd.read_csv('../data/train/%s.csv' % peptide_file))
real_file = '../data/train/%s.fasta' % peptide_file
decoder.load_state_dict(torch.load("../save_model/%s_decoder_model_1.pkl" % peptide_file, map_location=device))
os.makedirs('./sample_%s' % peptide_file, exist_ok=True)

In [9]:
# Noise schedule: abar_t = 1 - sqrt(t / (T + s)) for t = 1..T, with T = time_steps
# and s = sqrt_s from the config. Every schedule tensor is indexed by t - 1.
time_steps = int(conf_dict['time_steps'])
t = torch.linspace(1, time_steps, time_steps, device=device, dtype=torch.int)
alphas_cumprod = 1 - torch.sqrt(t / (time_steps + float(conf_dict['sqrt_s'])))
# Recover per-step betas from abar_t = prod_{k<=t} (1 - beta_k):
# beta_1 = 1 - abar_1 and beta_t = 1 - abar_t / abar_{t-1}. The leading 1 plays abar_0.
abar_shifted = torch.cat([torch.ones_like(alphas_cumprod[:1]), alphas_cumprod[:-1]])
betas = (1 - alphas_cumprod / abar_shifted).to(device)
alphas = 1. - betas
sqrt_alphas = torch.sqrt(alphas)
# abar_{t-1}. Entry 0 is padded with 0 rather than abar_0 = 1; its value is never
# used because ddpm_sample's t = 1 step returns pred_x0 directly.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=0.0)
sqrt_alphas_cumprod_prev = torch.sqrt(alphas_cumprod_prev)

# Sampled sequences
seq_list = []
# Pseudo-perplexity values
pseudo_perplexity_list = []
# Shannon entropy values
entropy_list = []
# token id -> token string
vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}
seq_index = 1

In [10]:
# Special token ids in the decoder vocabulary (0 = <cls>, 2 = <eos>, per the checks below).
CLS_ID, EOS_ID = 0, 2
# Samples containing any of these ids are rejected (special / gap tokens). <mask> is
# included because a decoded "<mask>" would be re-encoded as a real mask token in the
# pseudo-perplexity computation.
REJECT_IDS = (3, 29, 30, 31, tokenizer.mask_token_id)


def _decode_valid_sequences(seq_ids_list):
    """
    Convert argmax token ids (one row per sample) to peptide strings, keeping only
    well-formed samples: row starts with <cls>, contains an <eos>, contains no
    rejected id, and decodes to a non-empty peptide (empty ones give NaN perplexity).
    """
    sequences = []
    for seq_ids in seq_ids_list:
        if int(seq_ids[0]) != CLS_ID:
            continue
        if any(bool((seq_ids == tok).any()) for tok in REJECT_IDS):
            continue
        eos_positions = (seq_ids == EOS_ID).nonzero()
        if eos_positions.numel() == 0:
            continue
        eos_index = int(eos_positions[0])
        seq = tokenizer.decode(seq_ids[1:eos_index]).replace(" ", "").replace("<cls>", "").replace("<eos>", "").replace("<pad>", "")
        if seq:
            sequences.append(seq)
    return sequences


def _sample_sequences(num, index):
    """Sample `num` latents with ddpm_sample, decode them and return (x0, valid sequences)."""
    x0 = ddpm_sample(denoiser, [num, int(conf_dict['max_length']) + 2, int(conf_dict['denoiser_embedding'])], int(conf_dict['time_steps']), betas,
                     sqrt_alphas, alphas_cumprod, alphas_cumprod_prev, sqrt_alphas_cumprod_prev, cfs=cfs, use_attention=True, peptide_type=peptide_file,
                     index=index)
    seq_ids_list = decoder(x0).argmax(dim=-1)
    return x0, _decode_valid_sequences(seq_ids_list)


with torch.no_grad():
    target_num = sample_num * n
    # Reset accumulators so re-running this cell neither appends to nor crashes on
    # (the lists become tensors at the end) results of a previous run.
    seq_list, pseudo_perplexity_list, entropy_list = [], [], []
    seq_index = 1

    for round_idx in range(n):
        x0, new_seqs = _sample_sequences(sample_num, index=round_idx + 1)
        np.save('./sample_%s/sample_x0_%s' % (peptide_file, round_idx), x0.cpu().numpy())
        seq_list.extend(new_seqs)

    print("有效样本数量为%s条" % len(seq_list))
    # dict.fromkeys de-duplicates while keeping sampling order (a set would not,
    # making the written FASTA order differ between runs).
    unique_seqs = list(dict.fromkeys(seq_list))
    print("去重后的数量为%s条" % len(unique_seqs))
    # Sampling is stochastic and may yield duplicates or malformed samples: keep
    # topping up until target_num distinct valid sequences exist. The shortfall is
    # recomputed from the current unique count each time, so it can never go negative.
    while len(unique_seqs) < target_num:
        total_num = target_num - len(unique_seqs)
        print("继续采样%s条" % total_num)
        _, new_seqs = _sample_sequences(total_num, index=1)
        print("有效样本数量为%s条" % len(new_seqs))
        unique_seqs = list(dict.fromkeys(unique_seqs + new_seqs))
    seq_list = unique_seqs

    with open("./sample_%s/sample_result_metircs.fasta" % peptide_file, 'w') as f:
        for seq in seq_list:
            # Replace each 'X' with an independently drawn common amino acid (vocab ids 4..23).
            # str.replace used to swap every X for the same residue at once.
            seq = ''.join(vocab_dict[random.randrange(4, 24)] if aa == 'X' else aa for aa in seq)
            length = len(seq)

            # Pseudo-perplexity: mask each residue in turn (never the first/last special
            # token) and exponentiate the masked-LM loss over those positions.
            tensor_input = tokenizer.encode(seq, return_tensors='pt')
            repeat_input = tensor_input.repeat(tensor_input.size(-1) - 2, 1)
            # Row k masks token k + 1, i.e. the k-th residue.
            mask = torch.ones(tensor_input.size(-1) - 1).diag(1)[:-2]
            masked_input = repeat_input.masked_fill(mask == 1, tokenizer.mask_token_id)
            # -100 is ignored by the cross-entropy loss, so only masked positions count.
            # Named mlm_labels so the global class-id dict `labels` used by ddpm_sample
            # is not clobbered.
            mlm_labels = repeat_input.masked_fill(masked_input != tokenizer.mask_token_id, -100)
            # The ESM-2 MLM loss is returned already averaged.
            loss = esm2_model(masked_input.to(device), labels=mlm_labels.to(device)).loss
            pseudo_perplexity_list.append(loss.exp())

            # Shannon entropy (bits) of the residue composition.
            counts = Counter(seq)
            entropy = sum(-(c / length) * math.log2(c / length) for c in counts.values())
            entropy_list.append(entropy)

            f.write(">%s\n%s\n" % (seq_index, seq))
            seq_index += 1

    pseudo_perplexity_list = torch.stack(pseudo_perplexity_list).to(device)
    entropy_list = torch.tensor(entropy_list, device=device)
    print("Pseudo-Perplexity:%s±%s" % (pseudo_perplexity_list.mean().item(), torch.sqrt(pseudo_perplexity_list.var()).item()))
    print("entropy:%s±%s" % (entropy_list.mean().item(), torch.sqrt(entropy_list.var()).item()))

index 1 sampling loop time step: 100%|██████████████████████████████████████████████████████████████████████| 2000/2000 [04:01<00:00,  8.28it/s]


有效样本数量为1000条
去重后的数量为999条
继续采样1条


index 1 sampling loop time step: 100%|██████████████████████████████████████████████████████████████████████| 2000/2000 [00:26<00:00, 76.53it/s]


有效样本数量为1条
Pseudo-Perplexity:10.580612182617188±4.615673065185547
entropy:2.400797128677368±0.7949443459510803


In [11]:
# Instability index of every generated peptide (modlAMP GlobalDescriptor).
desc = GlobalDescriptor("./sample_%s/sample_result_metircs.fasta" % peptide_file)
desc.instability_index()
# desc.descriptor is presumably (num_sequences, 1). ravel keeps a 1-D array even for
# a single sequence, where squeeze would give a 0-d array that mean() cannot iterate.
instability_score = np.ravel(desc.descriptor)
# statistics.stdev needs at least two data points.
instability_spread = stdev(instability_score) if len(instability_score) > 1 else float('nan')
print("Instability: %.4f±%.4f" % (mean(instability_score), instability_spread))

Instability: 37.9084±59.1458


In [None]:
# Similarity score
@functools.lru_cache(maxsize=None)
def _reference_sequences(real_file):
    """Parse ``real_file`` once and cache its sequences (match_score runs once per sample).

    The cache is not invalidated if the file changes during the session.
    """
    return tuple(record.seq for record in SeqIO.parse(real_file, "fasta"))


@functools.lru_cache(maxsize=None)
def _blosum62_aligner():
    """Shared PairwiseAligner scored with BLOSUM62 (other settings left at their defaults)."""
    aligner = Align.PairwiseAligner()
    aligner.substitution_matrix = Align.substitution_matrices.load("BLOSUM62")
    return aligner


def match_score(sample_seq, real_file):
    """
    Mean BLOSUM62 pairwise-alignment score of ``sample_seq`` against every sequence in ``real_file``.

    A lower score means the generated peptide is less similar to (more novel than)
    the existing peptides of that dataset.

    :param sample_seq: generated peptide sequence.
    :param real_file: FASTA file with the reference peptides.
    :return: (sample_seq, mean alignment score).
    :raises ValueError: if ``real_file`` contains no sequences.
    """
    references = _reference_sequences(real_file)
    if not references:
        raise ValueError("no reference sequences found in %s" % real_file)
    aligner = _blosum62_aligner()
    # aligner.score returns the same optimal score as aligner.align(...).score but
    # skips building the traceback / alignment objects.
    scores = np.array([aligner.score(ref, sample_seq) for ref in references])
    return sample_seq, scores.mean()

# Average similarity of every generated peptide to the training peptides of its family.
reference_fasta = '../data/train/%s.fasta' % peptide_file
sample_dataset = FastaBatchedDataset.from_file("./sample_%s/sample_result_metircs.fasta" % peptide_file)
sample_len = len(sample_dataset)
match_score_list = []
for i in tqdm(range(sample_len), desc='Peptide similarity calculation:', total=sample_len):
    # Item [1] is the sequence string (items are presumably (label, sequence) pairs).
    data = sample_dataset[i]
    seq, score = match_score(data[1], reference_fasta)
    match_score_list.append(score)
print("Similarity: %.4f±%.4f" % (mean(match_score_list), stdev(match_score_list)))

Peptide similarity calculation::  85%|████████████████████████████████████████████████████████████▌          | 853/1000 [03:34<00:34,  4.31it/s]