In [26]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset, DistributedSampler
from torch.nn.utils.rnn import pad_sequence
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import os
import pickle
from Bio import SeqIO
import pandas as pd
import random
import numpy as np
import torch.distributed as dist
import torch.multiprocessing as mp
from sequence_embedding import SequenceToVectorModel

hyperparams = {
    'learning_rate': 0.0003,
    'batch_size': 24,
    'mlp_units': 128,
    'dropout_rate': 0.1,
    'num_epochs': 20,
    'd_model': 4,
    'd_inner': 48,
    'n_ssm': 1,
    'dt_rank': 1,
    'n_layer': 1,
    'dropout': 0.15,
    'n_splits': 4,  # actually not used now
    'max_seq_length': 4000
}

amino_acid_to_index = {aa: idx for idx, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
device = torch.device("cuda")
rdict_path = "/public/home/kngll/Mambaphase/model/merged_dict.pkl"

max_seq_length = hyperparams['max_seq_length']

sequence_model_params = {
            'd_model': hyperparams['d_model'],
            'd_inner': hyperparams['d_inner'],
            'n_ssm': hyperparams['n_ssm'],
            'dt_rank': hyperparams['dt_rank'],
            'vocab_size': len("ACDEFGHIKLMNPQRSTVWYU"),
            'n_layer': hyperparams['n_layer'],
            'n_heads': 4,  # 由于超参数未指定n_heads，这里保持原值，如需更改请补充
            'output_dim': 192,  # d_inner 一般做output_dim更合适
            'dropout': hyperparams['dropout']
        }






In [27]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch.nn.utils.rnn import pad_sequence


esm_weight_path = "/public/home/kngll/Mambaphase/data/esm2_t36_3B_UR50D_mlm_finetuned.pth"
esm_model_path = "/public/home/kngll/llps/data/esm2_t36_3B_UR50D"


def custom_collate_fn(batch):
    sequences, labels, rdict_seqs = zip(*batch)
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    labels = torch.stack(labels, 0)
    max_length = max(seq.size(0) for seq in rdict_seqs)
    padded_rdict_seqs = []
    for seq in rdict_seqs:
        padded_seq = torch.zeros(max_length)
        padded_seq[:seq.size(0)] = seq
        padded_rdict_seqs.append(padded_seq)
    rdict_seqs_padded = torch.stack(padded_rdict_seqs, 0)
    return sequences_padded, labels, rdict_seqs_padded


def infer_esm_rep(model, tokenizer, sequence, device):
    encoded_inputs = tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)
    encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}
    with torch.no_grad():
        with autocast():
            outputs = model(**encoded_inputs, output_hidden_states=True)
    representations = outputs.hidden_states[-1]
    last_hidden_state = representations[:, 0, :]
    torch.cuda.empty_cache()
    return last_hidden_state.squeeze(0).cpu()

sequences = [
    "MNRYLNRQRLYNMEEERNKYRGVMEPMSRMTMDFQGRYMDSQGRMVDPRYYDHYGRMHDYDRYYGRSMFNQGHSMDSQRYGGWMDNPERYMDMSGYQMDMQGRWMDAQGRYNNPFSQMWHSRQGHYPGEEEMSHHSMYGRNMHYPYHSHSASRHFDSPERWMDMSGYQMDMQGRWMDNYGRYVNPFHHHMYGRNMFYPYGSHCNNRHMEHPERYMDMSGYQMDMQGRWMDTHGRHCNPLGQMWHNRHGYYPGHPHGRNMFQPERWMDMSSYQMDMQGRWMDNYGRYVNPFSHNYGRHMNYPGGHYNYHHGRYMNHPERQMDMSGYQMDMHGRWMDNQGRYIDNFDRNYYDYHMY",
    # 可添加更多序列
]

tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
esm_model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
esm_model = esm_model.to(device)

print(f"共{len(sequences)}条序列，计算esm表示...")
esm_reps = []
for seq in sequences:
    if len(seq) > 4000:
        seq = seq[:4000]
    try:
        rep = infer_esm_rep(esm_model, tokenizer, seq, device)
        esm_reps.append(rep)
    except torch.cuda.OutOfMemoryError:
        print("OOM error! 忽略序列: ", seq[:10], "...")
        torch.cuda.empty_cache()
        esm_reps.append(torch.zeros(2560, dtype=torch.float))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


共1条序列，计算esm表示...


  with autocast():


In [31]:
class ClassificationModel(nn.Module):
    def __init__(self, sequence_model, rdict_dim=2560, mlp_units=384, dropout_rate=0.2, seq_emb_dim=192):
        super().__init__()
        self.sequence_model = sequence_model
        self.mlp1 = nn.Sequential(
            nn.Linear(rdict_dim, mlp_units),
            nn.ReLU(),
            nn.Linear(mlp_units, seq_emb_dim),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(seq_emb_dim * 2, mlp_units),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_units, 2)
        )
    def forward(self, input_ids, rdict_seqs):
        embeddings = self.sequence_model(input_ids)
        rdict_embeddings = self.mlp1(rdict_seqs)
        combined_embeddings = torch.cat((rdict_embeddings, embeddings), dim=1)
        logits = self.classifier(combined_embeddings)
        return logits
AA_LIST = "ACDEFGHIKLMNPQRSTVWYU"
def load_cls_model(cls_model_path):
    sequence_model = SequenceToVectorModel(
        vocab_size=len(AA_LIST),
        d_model=hyperparams['d_model'],
        d_inner=hyperparams['d_inner'],
        n_ssm=hyperparams['n_ssm'],
        dt_rank=hyperparams['dt_rank'],
        n_layer=hyperparams['n_layer'],
        dropout=hyperparams['dropout'],
        output_dim=192,
        n_heads=4
    )
    # 关键修改2：显式传递分类模型参数
    model = ClassificationModel(
        sequence_model=sequence_model,
        mlp_units=hyperparams['mlp_units'],
        dropout_rate=hyperparams['dropout_rate']
    ).to(device)
    # 关键修改3：设置weights_only=True并处理加载
    state_dict = torch.load(cls_model_path, map_location=device, weights_only=True)
    if any(k.startswith("module.") for k in state_dict):
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    else:
        new_state_dict = state_dict
    model.load_state_dict(new_state_dict, strict=False)
    model.eval()
    return model

cls_model_path = "/public/home/kngll/Mambaphase/data/modelscaffold/model_epoch19.pth"

print("加载下游分类模型...")
cls_model = load_cls_model(cls_model_path)

加载下游分类模型...


In [32]:
from torch.utils.data import Dataset, DataLoader
class PredictionDatasetV2(Dataset):
    def __init__(self, sequences, esm_reps, amino_acid_to_index):
        self.sequences = sequences
        self.esm_reps = esm_reps
        self.amino_acid_to_index = amino_acid_to_index

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        encoded_seq = torch.tensor([self.amino_acid_to_index.get(aa, 0) for aa in seq], dtype=torch.long)
        esm_rep = self.esm_reps[idx]
        return encoded_seq, esm_rep, seq
def collate_fn_predict(batch):
    seqs, esm_reps, origs = zip(*batch)
    padded_seqs = pad_sequence(seqs, batch_first=True, padding_value=0)
    esm_tensor = torch.stack([r if r.ndim==1 else r.squeeze(0) for r in esm_reps])
    return padded_seqs, esm_tensor, origs

print("准备分类数据...")
dataset = PredictionDatasetV2(sequences, esm_reps, amino_acid_to_index)
loader = DataLoader(dataset, batch_size=24, collate_fn=collate_fn_predict, shuffle=False)

准备分类数据...


In [33]:
print("进行预测...")
results = []
result_csv_path = "/public/home/kngll/Mambaphase/results/predictionclis.csv"
with torch.no_grad():
    for seqs, esm_reps, origs in loader:
        seqs = seqs.to(device)
        esm_reps = esm_reps.to(device)
        outputs = cls_model(seqs, esm_reps)
        probs = torch.softmax(outputs, dim=1)
        for i in range(len(origs)):
            results.append({
                "sequence": origs[i],
                "prob_0": probs[i][0].item(),
                "prob_1": probs[i][1].item(),
                "prediction": torch.argmax(probs[i]).item()
            })

df = pd.DataFrame(results)
os.makedirs(os.path.dirname(result_csv_path), exist_ok=True)
df.to_csv(result_csv_path, index=False)
print(f"已写入: {result_csv_path}")


进行预测...
已写入: /public/home/kngll/Mambaphase/results/predictionclis.csv
