In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sequence_embedding import SequenceToVectorModel
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch.cuda.amp import autocast

In [2]:
import os

BASE_DIR = os.getcwd()
BASE_DIR = os.path.dirname(BASE_DIR)



esm_model_path = os.path.join(BASE_DIR, "data", "esm2_t36_3B_UR50D1")
esm_weight_path = os.path.join(BASE_DIR, "data", "esm2_t36_3B_UR50D_mlm_finetuned.pth")
cls_model_path = os.path.join(BASE_DIR, "model", "weights2", "model_epoch_10.pth")
result_csv_path = os.path.join(BASE_DIR, "results", "predictions.csv")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
AA_LIST = "ACDEFGHIKLMNPQRSTVWYU"

print(esm_model_path, esm_weight_path, cls_model_path, result_csv_path, device)

/public/home/kngll/Mamba_phase/data/esm2_t36_3B_UR50D1 /public/home/kngll/Mamba_phase/data/esm2_t36_3B_UR50D_mlm_finetuned.pth /public/home/kngll/Mamba_phase/model/weights2/model_epoch_10.pth /public/home/kngll/Mamba_phase/results/predictions.csv cuda:1


In [3]:
best_hyperparams = {
    'd_model': 256,
    'd_inner': 128,    # 原64改为128（训练配置）
    'n_ssm': 8,
    'dt_rank': 8,      # 原1改为8（训练配置）
    'n_layer': 1,
    'dropout': 0.1,    # 原0.15改为0.1（训练配置）
    'mlp_units': 1024,
    'dropout_rate': 0.3
}

In [4]:
def infer_esm_rep(model, tokenizer, sequence, device):
    encoded_inputs = tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)
    encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}
    with torch.no_grad():
        with autocast():
            outputs = model(**encoded_inputs, output_hidden_states=True)
    representations = outputs.hidden_states[-1]
    last_hidden_state = representations[:, 0, :]
    torch.cuda.empty_cache()
    return last_hidden_state.squeeze(0).cpu()

sequences = [
    "MNRYLNRQRLYNMEEERNKYRGVMEPMSRMTMDFQGRYMDSQGRMVDPRYYDHYGRMHDYDRYYGRSMFNQGHSMDSQRYGGWMDNPERYMDMSGYQMDMQGRWMDAQGRYNNPFSQMWHSRQGHYPGEEEMSHHSMYGRNMHYPYHSHSASRHFDSPERWMDMSGYQMDMQGRWMDNYGRYVNPFHHHMYGRNMFYPYGSHCNNRHMEHPERYMDMSGYQMDMQGRWMDTHGRHCNPLGQMWHNRHGYYPGHPHGRNMFQPERWMDMSSYQMDMQGRWMDNYGRYVNPFSHNYGRHMNYPGGHYNYHHGRYMNHPERQMDMSGYQMDMHGRWMDNQGRYIDNFDRNYYDYHMY",
    # 可添加更多序列
]

tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
esm_model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
esm_model = esm_model.to(device)

print(f"共{len(sequences)}条序列，计算esm表示...")
esm_reps = []
for seq in sequences:
    if len(seq) > 4000:
        seq = seq[:4000]
    try:
        rep = infer_esm_rep(esm_model, tokenizer, seq, device)
        esm_reps.append(rep)
    except torch.cuda.OutOfMemoryError:
        print("OOM error! 忽略序列: ", seq[:10], "...")
        torch.cuda.empty_cache()
        esm_reps.append(torch.zeros(2560, dtype=torch.float))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


共1条序列，计算esm表示...


  with autocast():


In [6]:
from torch.utils.data import Dataset, DataLoader
class PredictionDatasetV2(Dataset):
    def __init__(self, sequences, esm_reps, amino_acid_to_index):
        self.sequences = sequences
        self.esm_reps = esm_reps
        self.amino_acid_to_index = amino_acid_to_index

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded_seq = torch.tensor([self.amino_acid_to_index.get(aa, 0) for aa in seq], dtype=torch.long)
        esm_rep = self.esm_reps[idx]
        return encoded_seq, esm_rep, seq
def collate_fn_predict(batch):
    seqs, esm_reps, origs = zip(*batch)
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=0)
    esm_tensor = torch.stack([r if r.ndim==1 else r.squeeze(0) for r in esm_reps])
    return padded_seqs, esm_tensor, origs
amino_acid_to_index = {aa: idx for idx, aa in enumerate(AA_LIST)}
print("准备分类数据...")
dataset = PredictionDatasetV2(sequences, esm_reps, amino_acid_to_index)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn_predict, shuffle=False)

准备分类数据...


In [9]:
class ClassificationModel(nn.Module):
    def __init__(self, sequence_model, mlp_units=512, dropout_rate=0.4):
        super().__init__()
        self.sequence_model = sequence_model
        
        # 增强的MLP结构（4层）
        self.mlp1 = nn.Sequential(
            nn.Linear(2560, mlp_units),
            nn.BatchNorm1d(mlp_units),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_units, mlp_units//2),
            nn.BatchNorm1d(mlp_units//2),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_units//2, mlp_units//4),
            nn.BatchNorm1d(mlp_units//4),
            nn.SiLU(),
            nn.Dropout(dropout_rate)
        )
        
        # 分类器（2层）
        self.classifier = nn.Sequential(
            nn.Linear(mlp_units//4 + 256, 512),
            nn.BatchNorm1d(512),
            nn.SiLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(512, 2)
        )

    def forward(self, input_ids, rdict_seqs):
        embeddings = self.sequence_model(input_ids)
        rdict_emb = self.mlp1(rdict_seqs)
        combined = torch.cat([rdict_emb, embeddings], dim=1)
        return self.classifier(combined)

def load_cls_model(cls_model_path):
    sequence_model = SequenceToVectorModel(
        vocab_size=len(AA_LIST),
        d_model=best_hyperparams['d_model'],
        d_inner=best_hyperparams['d_inner'],
        n_ssm=best_hyperparams['n_ssm'],
        dt_rank=best_hyperparams['dt_rank'],
        n_layer=best_hyperparams['n_layer'],
        dropout=best_hyperparams['dropout'],
        output_dim=256
    )
    # 关键修改2：显式传递分类模型参数
    model = ClassificationModel(
        sequence_model,
        mlp_units=best_hyperparams['mlp_units'],
        dropout_rate=best_hyperparams['dropout_rate']
    ).to(device)
    # 关键修改3：设置weights_only=True并处理加载
    state_dict = torch.load(cls_model_path, map_location=device, weights_only=True)
    if any(k.startswith("module.") for k in state_dict):
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    else:
        new_state_dict = state_dict
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model

print("加载下游分类模型...")
cls_model = load_cls_model(cls_model_path)

加载下游分类模型...


In [11]:
print("进行预测...")
results = []
with torch.no_grad():
    for seqs, esm_reps, origs in loader:
        seqs = seqs.to(device)
        esm_reps = esm_reps.to(device)
        outputs = cls_model(seqs, esm_reps)
        probs = torch.softmax(outputs, dim=1)
        for i in range(len(origs)):
            results.append({
                "sequence": origs[i],
                "prob_0": probs[i][0].item(),
                "prob_1": probs[i][1].item(),
                "prediction": torch.argmax(probs[i]).item()
            })

df = pd.DataFrame(results)
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
df.to_csv(result_csv_path, index=False)
print(f"已写入: {result_csv_path}")


进行预测...
已写入: /public/home/kngll/Mamba_phase/results/predictions.csv
