In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch.cuda.amp import autocast

# ------------------- Configuration --------------------
# Base ESM-2 model directory and the fine-tuned weights that are loaded on top of it.
esm_model_path = "/public/home/kngll/llps/data/esm2_t36_3B_UR50D"
esm_weight_path = "/public/home/kngll/Mambaphase/data/esm2_t36_3B_UR50D_mlm_finetuned.pth"
# Classifier checkpoint and the CSV file the predictions are written to.
cls_model_path = "/public/home/kngll/Mambaphase/model/saltweight/best_model.pth"
result_csv_path = "/public/home/kngll/Mambaphase/results/predictions.csv"
# An earlier, commented-out configuration used upper-case names for the same
# paths and pointed the classifier at
# /public/home/kngll/Mambaphase/model/weights2/model_epoch_5.pth instead.

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = DEVICE  # lower-case name kept for code that refers to `device`
BATCH_SIZE = 64
AA_LIST = "ACDEFGHIKLMNPQRSTVWYU"


# Map each residue letter in AA_LIST to its position in that string.
amino_acid_to_index = dict(zip(AA_LIST, range(len(AA_LIST))))

def infer_esm_rep(model, tokenizer, sequence, device):
    """Return the first-token vector of the model's last hidden layer for ``sequence``.

    Args:
        model: Callable invoked as ``model(**encoded, output_hidden_states=True)``;
            its result must expose a ``hidden_states`` sequence.
        tokenizer: Callable invoked as
            ``tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)``
            and returning a mapping of tensors.
        sequence: The sequence text handed to the tokenizer.
        device: ``torch.device`` (or device string) the encoded inputs are moved to.

    Returns:
        A CPU tensor holding ``hidden_states[-1][:, 0, :]`` with a size-1 batch
        axis squeezed away.
    """
    device = torch.device(device)
    on_cuda = device.type == "cuda"

    # NOTE(review): the recorded notebook output shows the tokenizer warning that
    # no maximum length is defined, so truncation=True appears to be a no-op for
    # this tokenizer; the calling loop caps the sequence length itself.
    encoded_inputs = tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)
    encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast(). Mixed
    # precision is only enabled on CUDA, matching the old context manager, which
    # disabled itself when CUDA was unavailable.
    autocast_device = "cuda" if on_cuda else "cpu"
    with torch.no_grad(), torch.autocast(device_type=autocast_device, enabled=on_cuda):
        outputs = model(**encoded_inputs, output_hidden_states=True)

    # Position 0 of the final layer is used as the per-sequence representation.
    first_token_rep = outputs.hidden_states[-1][:, 0, :].squeeze(0).cpu()

    # Drop the device-side references before releasing cached blocks; otherwise
    # empty_cache() cannot return the memory they still occupy.
    del outputs, encoded_inputs
    if on_cuda:
        torch.cuda.empty_cache()
    return first_token_rep

# Sequences to embed.
sequences = [
    "GHGVYGHGVYGHGPYGHGPYGHGLYW",
]

# Longer sequences are cut to this many characters before tokenization.
max_seq_len = 4000

tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
esm_model = AutoModelForMaskedLM.from_pretrained(esm_model_path)

# map_location="cpu" lets a checkpoint saved from a GPU load on any host (the
# model is moved to DEVICE afterwards); weights_only=True restricts unpickling
# to tensors/containers, which is all a state dict needs.
finetuned_state = torch.load(esm_weight_path, map_location="cpu", weights_only=True)
# strict=False tolerates partial checkpoints, so report mismatches instead of
# letting a checkpoint with wrong key names load nothing without any notice.
load_result = esm_model.load_state_dict(finetuned_state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Warning: fine-tuned checkpoint did not match the model exactly "
        f"({len(load_result.missing_keys)} missing keys, "
        f"{len(load_result.unexpected_keys)} unexpected keys)."
    )
del finetuned_state
esm_model = esm_model.to(DEVICE)
# Make inference mode explicit so dropout is never active while embedding.
esm_model.eval()

print(f"共{len(sequences)}条序列，计算esm表示...")
esm_reps = []
# Indices of sequences whose embedding is a zero placeholder (see below).
failed_indices = []
# Width of the placeholder vector, taken from the model configuration so it
# always matches the real embeddings.
hidden_size = esm_model.config.hidden_size
for seq_idx, seq in enumerate(sequences):
    seq = seq[:max_seq_len]
    try:
        rep = infer_esm_rep(esm_model, tokenizer, seq, DEVICE)
        esm_reps.append(rep)
    except torch.cuda.OutOfMemoryError:
        print("OOM error! 忽略序列: ", seq[:10], "...")
        torch.cuda.empty_cache()
        # Keep esm_reps aligned with sequences by inserting an all-zero vector,
        # and remember the index so the placeholder can be told apart later.
        failed_indices.append(seq_idx)
        esm_reps.append(torch.zeros(hidden_size, dtype=torch.float))

if failed_indices:
    print(
        f"Warning: {len(failed_indices)} sequence(s) ran out of memory and have "
        f"zero-vector placeholders in esm_reps (indices: {failed_indices})."
    )

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


共1条序列，计算esm表示...


  with autocast():


In [3]:
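# Task: given the already-extracted feature list `esm_reps`, load the classifier checkpoint, run inference, and save the results to /public/home/kngll/Mambaphase/results/predictions.csv.
# NOTE(review): this note originally named /public/home/kngll/Mambaphase/model/phweight/best_model.pth as the checkpoint, but the MODEL_PATH used in this cell resolves to .../model/saltweight/best_model.pth — confirm which checkpoint is intended.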
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from torch import nn

class MLPClassifier(nn.Module):
    """Fully connected classifier: input_dim -> 1024 -> 512 -> 256 -> num_classes.

    Each hidden linear layer is followed by a ReLU; the first two activations
    are additionally followed by dropout (p=0.5, then p=0.3). The final linear
    layer has no activation.
    """

    def __init__(self, input_dim, num_classes):
        """Build the layer stack for ``input_dim`` features and ``num_classes`` outputs."""
        super().__init__()
        stack = []
        width_in = input_dim
        # (hidden width, dropout probability or None) for each hidden stage.
        for width_out, drop_p in ((1024, 0.5), (512, 0.3), (256, None)):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.ReLU())
            if drop_p is not None:
                stack.append(nn.Dropout(drop_p))
            width_in = width_out
        stack.append(nn.Linear(width_in, num_classes))
        # The attribute name and module order determine the state-dict keys
        # ("net.<index>.weight" / "net.<index>.bias").
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the resulting scores."""
        scores = self.net(x)
        return scores

def predict_and_save(esm_reps, model_path, save_path):
    """Classify pre-computed embeddings with the saved MLP and write the results to CSV.

    Args:
        esm_reps: A single feature or a list/tuple of features. Each feature is a
            ``torch.Tensor`` or ``numpy.ndarray`` of shape ``(2560,)`` (one
            sample) or ``(n, 2560)`` (``n`` samples).
        model_path: Path to the ``MLPClassifier`` state-dict checkpoint.
        save_path: Destination CSV path; missing parent directories are created.

    Returns:
        ``pandas.DataFrame`` with one row per sample and the columns
        ``prediction``, ``prob_low``, ``prob_mid`` and ``prob_high``. When no
        samples are supplied the frame (and the CSV) has these columns and no
        rows, and the checkpoint is not loaded.

    Raises:
        ValueError: If a feature has an unsupported type, is not 1-D/2-D, or
            does not have 2560 values per sample.
        Exception: Any other failure is printed and re-raised unchanged.
    """
    # Input width and class count the checkpoint was trained with.
    input_dim, num_classes = 2560, 3

    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Make sure the output directory exists.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    try:
        # Accept a bare feature as well as a list/tuple of features.
        if not isinstance(esm_reps, (list, tuple)):
            esm_reps = [esm_reps]

        # Validate and normalise every feature before touching the checkpoint,
        # so malformed input fails fast with a clear message.
        processed_features = []
        for feat in esm_reps:
            if isinstance(feat, np.ndarray):
                tensor_feat = torch.tensor(feat, dtype=torch.float32)
            elif isinstance(feat, torch.Tensor):
                tensor_feat = feat.float()
            else:
                raise ValueError("Unsupported feature type")

            # Promote a single sample to a batch of one.
            if tensor_feat.dim() == 1:
                tensor_feat = tensor_feat.unsqueeze(0)
            elif tensor_feat.dim() != 2:
                raise ValueError(f"Invalid feature dimension: {tensor_feat.shape}")

            # Otherwise the mismatch only surfaces as a matmul shape error.
            if tensor_feat.shape[1] != input_dim:
                raise ValueError(
                    f"Expected {input_dim} values per sample, "
                    f"got feature of shape {tuple(tensor_feat.shape)}"
                )

            processed_features.append(tensor_feat)

        n_samples = sum(t.shape[0] for t in processed_features)
        if n_samples == 0:
            # torch.cat rejects an empty list; produce an empty, correctly
            # shaped result instead of an opaque runtime error.
            preds_np = np.empty(0, dtype=np.int64)
            probs_np = np.empty((0, num_classes), dtype=np.float32)
        else:
            # Load the classifier. weights_only=True limits unpickling to
            # tensors/containers, which is all a state dict needs.
            model = MLPClassifier(input_dim, num_classes).to(device)
            state_dict = torch.load(model_path, map_location=device, weights_only=True)
            model.load_state_dict(state_dict)
            model.eval()

            # Stack all samples into one batch on the target device.
            batch_data = torch.cat(processed_features).to(device)

            # Run the prediction.
            with torch.no_grad():
                outputs = model(batch_data)
                probs = torch.softmax(outputs, dim=1)
                preds = outputs.argmax(dim=1)

            preds_np = preds.cpu().numpy()
            probs_np = probs.cpu().numpy()

        # Assemble the result table.
        results = pd.DataFrame({
            "prediction": preds_np,
            "prob_low": probs_np[:, 0],
            "prob_mid": probs_np[:, 1],
            "prob_high": probs_np[:, 2]
        })

        # Persist the results.
        results.to_csv(save_path, index=False)
        print(f"Successfully saved predictions to {save_path}")

        return results

    except Exception as e:
        print(f"Error during prediction: {str(e)}")
        raise

# Example usage
if __name__ == "__main__":
    # `esm_reps` must already exist in the session: a list of features, each a
    # tensor or numpy array of shape (2560,).

    # Checkpoint and output locations come from the configuration cell
    # (`cls_model_path`, `result_csv_path`) so the paths are defined only once
    # and cannot drift from it.
    MODEL_PATH = cls_model_path
    SAVE_PATH = result_csv_path

    # Run the prediction and keep the resulting DataFrame.
    predictions = predict_and_save(
        esm_reps=esm_reps,
        model_path=MODEL_PATH,
        save_path=SAVE_PATH
    )

  model.load_state_dict(torch.load(model_path, map_location=device))


Successfully saved predictions to /public/home/kngll/Mambaphase/results/predictions.csv
