In [1]:
import os
os.getcwd()

'/home/a1356913256/python/molecular_generation/CPL-Diff/scripts'

In [2]:
import torch
import configparser
import pandas as pd
import torch.nn.functional as F
from CPLDiff.models.Denoiser import Denoiser
from CPLDiff.models.layers.EMA import EMA
from CPLDiff.utils.CPLDiffDataset import XYDataset
from CPLDiff.utils.utils import set_seed
from CPLDiff.utils.utils import extract
from transformers import EsmTokenizer, EsmModel
from timm.scheduler.cosine_lr import CosineLRScheduler
from colorama import Fore, Style
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torch import device
from torch.cuda import is_available

In [3]:
# 获取可以使用的硬件资源
device = device("cuda:0" if is_available() else "cpu")
device

device(type='cuda', index=0)

In [4]:
"""读取配置文件"""
conf = configparser.ConfigParser()
conf.read('../config.ini')
conf_dict = dict(conf.items('CPLDiff_conf'))
float(conf_dict['denoiser_lr'])

0.0005

In [5]:
set_seed(int(conf_dict['seed']))
if not os.path.exists('../save_model'):
    os.mkdir('../save_model')

In [6]:
"""设置扩散模型的噪声计划表"""
# 《Diffusion-LM Improves Controllable Text Generation》采取如下时间表形式
t = torch.arange(1, int(conf_dict['time_steps']) + 1, dtype=torch.long, device=device)
# 求上头带横线的α
alphas_cumprod = 1 - torch.sqrt(t / (int(conf_dict['time_steps']) + float(conf_dict['sqrt_s'])))
# 求上头带横线的α，然后开根号
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
# 求1减去上头带横线的α，然后开根号
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

In [7]:
# token编码器
tokenizer = EsmTokenizer.from_pretrained('../' + conf_dict['denoiser_esm_model_name'])
# 用于获取蛋白质潜空间嵌入
esm2_model = EsmModel.from_pretrained('../' + conf_dict['denoiser_esm_model_name'], add_pooling_layer=True).to(device)
esm2_model.eval()

Some weights of EsmModel were not initialized from the model checkpoint at ../ESM2-8M and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.2, inplace=False)
    (position_embeddings): Embedding(1026, 320, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0): EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.2, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.2, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          

In [8]:
# 加载数据集
train_dataset = XYDataset(pd.read_csv('../data/train/mulit_peptide_train.csv'))
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=int(conf_dict['diffusion_batch_size']), pin_memory=True, shuffle=True,
                                                persistent_workers=True, num_workers=8)
val_dataset = XYDataset(pd.read_csv('../data/train/mulit_peptide_val.csv'))
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1024, pin_memory=True, persistent_workers=True, num_workers=8)

In [9]:
# 去噪模型
denoiser_mlp = [int(i) for i in conf_dict['denoiser_mlp'].split(',')]
denoiser = Denoiser('../' + conf_dict['denoiser_esm_model_name'], int(conf_dict['denoiser_embedding']), denoiser_mlp).to(device)
optimizer = torch.optim.AdamW(denoiser.parameters(), lr=float(conf_dict['denoiser_lr']), weight_decay=float(conf_dict['denoiser_weight_decay']))
scheduler = CosineLRScheduler(optimizer, t_initial=200_000, lr_min=float(conf_dict['min_denoiser_lr']), warmup_lr_init=1e-8, warmup_t=10_000, cycle_limit=1,
                              t_in_epochs=False)
# 初始化EMA
ema = EMA(denoiser, 0.99)
ema.register()
denoiser

Some weights of EsmModel were not initialized from the model checkpoint at ../ESM2-8M and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Denoiser(
  (time_emb): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=640, out_features=1280, bias=True)
    (2): SiLU()
    (3): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (label_emb): LabelEmbedder(
    (embedding_table): Embedding(4, 1280)
  )
  (esm_attention_list): ModuleList(
    (0): EsmAttention(
      (self): EsmSelfAttention(
        (query): Linear(in_features=320, out_features=320, bias=True)
        (key): Linear(in_features=320, out_features=320, bias=True)
        (value): Linear(in_features=320, out_features=320, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (rotary_embeddings): RotaryEmbedding()
      )
      (output): EsmSelfOutput(
        (dense): Linear(in_features=320, out_features=320, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (1): EsmAttention(
      (self): EsmSelfAttention(
      

In [10]:
print("The total number of iteration steps is " + Fore.LIGHTGREEN_EX + "%s" % int(conf_dict['denoiser_epoch']) + Style.RESET_ALL + ".")
epoch = 0
x_steps = 1
# 检查是否存在之前的检查点文件
if os.path.exists("../save_model/checkpoint.pth"):
    # 加载模型状态
    checkpoint = torch.load("../save_model/checkpoint.pth")
    denoiser.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    ema = checkpoint['ema']
    # 由于保存的是已经训练好的epoch，因此从下一个epoch开始
    epoch = checkpoint['epoch'] + 1
    x_steps = checkpoint['x_steps']
    print(f"Continuing from epoch {epoch}")
    del checkpoint
writer = SummaryWriter(log_dir='./logs')

The total number of iteration steps is [92m680[0m.


In [None]:
pbar = tqdm(total=len(train_data_loader), miniters=0)
end_train = False
for i in range(epoch, int(conf_dict['denoiser_epoch'])):
    if i % 100 == 0:
        set_seed(int(conf_dict['seed']) + i)
    """
    为避免图像重叠，每隔一定的epoch创建一个log文件;
    如果是从非设定的阈值的倍数的epoch中断的，则删掉最新的log文件，避免重叠；
    为了保证能被阈值整除的记录在旧文件里，因此无需加1求余
    """
    if epoch % 20 == 0:
        writer = SummaryWriter(log_dir='./logs')
    denoiser.train()
    for index, datas in enumerate(train_data_loader):
        # 获取序列的索引列表，加2是因为会在头尾额外插入开始和结束标记
        seq_encode = tokenizer(datas['sequences'], max_length=int(conf_dict['max_length']) + 2, padding='max_length', return_tensors="pt")
        seq_ids_list, train_attention_mask = seq_encode['input_ids'].to(device), seq_encode['attention_mask'].to(device)
        # 获取蛋白质潜空间嵌入
        with torch.no_grad():
            x0 = esm2_model(seq_ids_list, train_attention_mask).last_hidden_state

        t = torch.randint(1, int(conf_dict['time_steps']) + 1, (x0.shape[0],), device=device, dtype=torch.long)
        # 随机采样噪声
        noise = torch.randn_like(x0, device=device)
        # 按照给定的时间步取对应的系数，减1是因为下标是从0开始的
        sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t - 1, x0.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t - 1, x0.shape)
        # 第一项，sqrt(ā_t)*x_0
        term1 = sqrt_alphas_cumprod_t * x0
        # 第二项，噪声乘上sqrt(1-ā_t)
        term2 = sqrt_one_minus_alphas_cumprod_t * noise
        # 加噪完毕后的数据
        x_t = torch.add(term1, term2)

        optimizer.zero_grad()
        # 预测x0
        pred_x0 = denoiser(x_t, t, y=datas['labels'].to(device), attention_mask=train_attention_mask)
        x0_loss = F.mse_loss(pred_x0, x0)
        x0_loss.backward()
        # 梯度裁剪
        torch.nn.utils.clip_grad_norm_(denoiser.parameters(), max_norm=float(conf_dict['denoiser_clip_grad']))
        optimizer.step()

        if x_steps % 20 == 0:
            pbar.set_description_str("epoch:%s   " % (epoch + 1) + Fore.LIGHTRED_EX + "loss:%.4f" % x0_loss.item() + Style.RESET_ALL)
            writer.add_scalar("train/x0_loss", x0_loss.item(), x_steps)
        # 指数移动平均
        ema.update()

        x_steps += 1
        pbar.update()
        scheduler.step_update(x_steps)

    pbar.reset()
    pbar.clear()

    val_x0_loss_list = []
    # 验证模型对于去噪的性能
    with torch.no_grad():
        # eval前，将影子权重应用到模型中
        ema.apply_shadow()
        denoiser.eval()
        for val_datas in val_data_loader:
            val_seq_encode = tokenizer(val_datas['sequences'], max_length=int(conf_dict['max_length']) + 2, padding='max_length', return_tensors="pt")
            val_seq_ids_list, val_attention_mask = val_seq_encode['input_ids'].to(device), val_seq_encode['attention_mask'].to(device)
            val_x0 = esm2_model(val_seq_ids_list, val_attention_mask).last_hidden_state
            val_t = torch.randint(0, int(conf_dict['time_steps']), (val_x0.shape[0],), device=device).long()

            val_noise = torch.randn_like(val_x0, device=device)
            val_sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, val_t, val_x0.shape)
            val_sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, val_t, val_x0.shape)
            val_term1 = val_sqrt_alphas_cumprod_t * val_x0
            val_term2 = val_sqrt_one_minus_alphas_cumprod_t * val_noise
            val_x_t = torch.add(val_term1, val_term2)

            val_pred_x0 = denoiser(val_x_t, val_t, y=val_datas['labels'].to(device), attention_mask=val_attention_mask)
            val_x0_loss_list.append(F.mse_loss(val_pred_x0, val_x0))

        val_x0_loss = torch.stack(val_x0_loss_list).mean()
        writer.add_scalar("val/x0_loss", val_x0_loss.item(), (epoch + 1))

        if (epoch + 1) % 20 == 0:
            torch.save(denoiser.state_dict(), "../save_model/denoise_model.pkl")
        state = {
            'model_state_dict': denoiser.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'ema': ema,
            'epoch': epoch,
            'x_steps': x_steps
        }
        torch.save(state, "../save_model/checkpoint.pth")
        epoch += 1
        # eval之后，恢复原来模型的参数
        ema.restore()
        if end_train:
            break

    # 最后保存的模型选择应用EMA后的
    ema.apply_shadow()
    torch.save(denoiser.state_dict(), "../save_model/denoise_model.pkl")

epoch:129   [91mloss:0.0154[0m:  97%|██████████████████████████████████████████████████████▍ | 286/294 [00:14<00:00, 21.92it/s][0m