In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch.cuda.amp import autocast
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM

current_dir = os.current_dir = os.path.abspath("")


# esm_model_path = os.path.join(current_dir,"data", "esm2_t36_3B_UR50D")
# esm_weight_path = os.path.join(current_dir, "data","esm2_t36_3B_UR50D_mlm_finetuned.pth")

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t36_3B_UR50D")
model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t36_3B_UR50D")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
AA_LIST = "ACDEFGHIKLMNPQRSTVWYU"

sequences = [
    "GHGVYGHGVYGHGPYGHGPYGHGLYW",
]

amino_acid_to_index = {aa: idx for idx, aa in enumerate(AA_LIST)}

def infer_esm_rep(model, tokenizer, sequence, device):
    encoded_inputs = tokenizer(sequence, return_tensors='pt', padding=True, truncation=True)
    encoded_inputs = {k: v.to(device) for k, v in encoded_inputs.items()}
    with torch.no_grad():
        with autocast():
            outputs = model(**encoded_inputs, output_hidden_states=True)
    representations = outputs.hidden_states[-1]
    last_hidden_state = representations[:, 0, :]
    torch.cuda.empty_cache()
    return last_hidden_state.squeeze(0).cpu()

# tokenizer = AutoTokenizer.from_pretrained(esm_model_path)
# esm_model = AutoModelForMaskedLM.from_pretrained(esm_model_path)
esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
esm_model = esm_model.to(DEVICE)

print(f"共{len(sequences)}条序列，计算esm表示...")
esm_reps = []
for seq in sequences:
    if len(seq) > 4000:
        seq = seq[:4000]
    try:
        rep = infer_esm_rep(esm_model, tokenizer, seq, DEVICE)
        esm_reps.append(rep)
    except torch.cuda.OutOfMemoryError:
        print("OOM error! 忽略序列: ", seq[:10], "...")
        torch.cuda.empty_cache()
        esm_reps.append(torch.zeros(2560, dtype=torch.float))

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  esm_model.load_state_dict(torch.load(esm_weight_path), strict=False)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


共1条序列，计算esm表示...


  with autocast():


In [8]:
#psp calculation
import sys
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_dir = os.current_dir = os.path.abspath("")
psp_dir = os.path.join(current_dir, "psp")

model_path = os.path.join(current_dir, "model", "pspweight", "best_model.pth")

sys.path.insert(0, psp_dir) 

from infer import Predictor

predictor = Predictor(
    model_path=model_path,
    device=device
)

# 执行预测
pspprobs = predictor.predict(esm_reps)

In [12]:
#scafold&client calculation

current_dir = os.current_dir = os.path.abspath("")
sca_dir = os.path.join(current_dir, "cli2scafold")

model_path = os.path.join(current_dir, "model", "clientweight", "best_model.pth")

sys.path.insert(0, sca_dir) 

from inferCli import CliPredictor

predictor = CliPredictor(
    model_path=model_path,
    device=device
)

# 执行预测
pspprobs = predictor.predict(esm_reps)

In [13]:
#ph calculation
current_dir = os.current_dir = os.path.abspath("")
ph_dir = os.path.join(current_dir, "phsalt")


model_path = os.path.join(current_dir, "model", "phweight", "best_model.pth")

sys.path.insert(0, ph_dir)

from phinfer import PHpredictor

predictor = PHpredictor(
    model_path=model_path,
)

# 执行预测
phprobs = predictor.predict(esm_reps)

  self.model.load_state_dict(torch.load(model_path, map_location=self.device))


In [15]:



#salt calculation
current_dir = os.current_dir = os.path.abspath("")
salt_dir = os.path.join(current_dir, "phsalt")

model_path = os.path.join(current_dir, "model", "saltweight", "best_model.pth")

sys.path.insert(0, salt_dir)

from saltinfer import saltPredictor
predictor = saltPredictor(
    model_path=model_path,
)
# 执行预测
saltprobs = predictor.predict(esm_reps)

  self.model.load_state_dict(torch.load(model_path, map_location=self.device))
