In [None]:
import torch
import torch.nn as nn
from transformers import EsmModel, EsmTokenizer
import warnings

In [None]:
# Hugging Face checkpoint id of the ESM-2 model used in this notebook.
model_name = "facebook/esm2_t33_650M_UR50D"

# Key point: load the encoder without the pooler layer.
esm = EsmModel.from_pretrained(model_name, add_pooling_layer=False)



In [None]:
# Tokenizer that matches the checkpoint selected above.
tokenizer = EsmTokenizer.from_pretrained(model_name)

# Example amino-acid sequence to encode.
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAED"

# Encode to PyTorch tensors; pad, and truncate anything longer than 1024 tokens.
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024)

In [None]:


# Ignore warnings.
# NOTE(review): this installs a process-wide "ignore" filter for every warning
# category, so it also hides warnings raised by later cells.
warnings.filterwarnings("ignore")

class PPIModel(nn.Module):
    """Protein-protein interaction scorer built on top of an ESM encoder.

    Each protein is embedded independently by the ESM model (run under
    ``torch.no_grad()``, so no gradients reach the encoder). The two
    embeddings are concatenated and passed through an MLP head whose final
    sigmoid yields a score in [0, 1].

    Args:
        model_name: Hugging Face checkpoint id passed to ``from_pretrained``
            for both the encoder and the tokenizer.
        hidden_dim: Width of the first hidden layer of the classifier; the
            second hidden layer has ``hidden_dim // 2`` units.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, model_name="facebook/esm2_t33_650M_UR50D", hidden_dim=512, dropout=0.1):
        super().__init__()

        # Load the ESM encoder without the pooler layer; embeddings are
        # pooled manually in `_get_protein_embedding`.
        self.esm = EsmModel.from_pretrained(
            model_name,
            add_pooling_layer=False,
        )
        self.tokenizer = EsmTokenizer.from_pretrained(model_name)

        esm_dim = self.esm.config.hidden_size

        # Custom interaction classifier.
        self.classifier = nn.Sequential(
            nn.Linear(esm_dim * 2, hidden_dim),  # input: both protein embeddings concatenated
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),  # single output unit for binary classification
            nn.Sigmoid(),
        )

    def forward(self, seq_a, seq_b):
        """Score the interaction between two proteins (or two batches of proteins).

        Args:
            seq_a: Sequence input for protein A; forwarded unchanged to the
                tokenizer.
            seq_b: Sequence input for protein B; forwarded unchanged to the
                tokenizer.

        Returns:
            The sigmoid output of the classifier with all size-1 dimensions
            removed: a 0-dim tensor for a single pair, shape ``(N,)`` for a
            batch of N > 1 pairs.
        """
        emb_a = self._get_protein_embedding(seq_a)
        emb_b = self._get_protein_embedding(seq_b)

        # Concatenate the two embeddings along the feature axis and classify.
        combined = torch.cat([emb_a, emb_b], dim=-1)
        output = self.classifier(combined)
        return output.squeeze()

    def _get_protein_embedding(self, sequence):
        """Return one fixed-size embedding per input sequence.

        The sequence input is tokenized (padded, truncated to 1024 tokens),
        run through the encoder without gradient tracking, and mean-pooled
        over the token axis. Padding positions are excluded from the mean via
        the attention mask, so the embedding of a sequence does not depend on
        how much padding the rest of the batch forced onto it. Every position
        the attention mask marks as real still contributes, including any
        special tokens the tokenizer adds.

        Args:
            sequence: Sequence input forwarded unchanged to the tokenizer.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        encoded = self.tokenizer(
            sequence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,
        )

        # Move the inputs to whichever device the encoder lives on.
        device = next(self.esm.parameters()).device
        encoded = {key: tensor.to(device) for key, tensor in encoded.items()}

        with torch.no_grad():
            outputs = self.esm(**encoded)

        hidden = outputs.last_hidden_state
        attention_mask = encoded.get("attention_mask")
        if attention_mask is None:
            # No mask available: fall back to a plain mean over all positions.
            return hidden.mean(dim=1)

        # Masked mean: zero out padded positions, then divide by the number
        # of real tokens (clamped so an all-padding row cannot divide by zero).
        weights = attention_mask.unsqueeze(-1).to(hidden.dtype)
        token_counts = weights.sum(dim=1).clamp(min=1.0)
        return (hidden * weights).sum(dim=1) / token_counts


In [None]:
model = PPIModel()
# A freshly constructed nn.Module is in training mode. Switch to eval mode so
# the Dropout layers in the classifier head are disabled and the prediction
# below is deterministic (torch.no_grad() alone does not turn dropout off).
model.eval()

# Example protein sequences.
protein_a = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAED"
protein_b = "GIVEQCCTSICSLYQLENYCN"

# Predict the interaction probability.
# NOTE(review): no training or weight loading for the classifier head is
# visible in this notebook, so this value presumably comes from randomly
# initialised weights and only serves as a smoke test.
with torch.no_grad():
    probability = model(protein_a, protein_b)
    print(f"相互作用概率: {probability.item():.4f}")

In [None]:
# Load the medium-sized ESM-2 model.
# NOTE: from here on `model` refers to a plain EsmModel and no longer to the
# PPIModel instance created in the earlier cell.
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
# Skip the pooler layer, consistent with the other loads in this notebook;
# only `last_hidden_state` is read from this model's outputs below.
model = EsmModel.from_pretrained(model_name, add_pooling_layer=False)

In [None]:
# Prepare the protein sequence.
sequence = "MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGERGFFYTPKTRREAED"

# Encode the sequence as PyTorch tensors for inference.
inputs = tokenizer(
    sequence,
    return_tensors="pt",
    padding=True,
    truncation=True,
)


In [None]:
# Bare expression: shows the tokenizer output as this cell's result.
inputs

In [None]:
# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    outputs = model(**inputs)

# Report the shape of the final hidden states as a quick installation check.
print("安装成功！模型输出形状:", outputs.last_hidden_state.size())