In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Checkpoint identifier shared by the tokenizer and the masked-LM model, so the
# two can never be loaded from different checkpoints by accident.
# NOTE(review): this looks like a local directory name; the Hugging Face Hub id
# for this checkpoint is "facebook/esm2_t30_150M_UR50D" — confirm which is intended.
MODEL_NAME = "esm2_t30_150m_ur50d"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

In [None]:
# Data source: https://www.uniprot.org/uniprotkb

def _iter_fasta_records(lines):
    """Yield one ``{'id': ..., 'sequence': ...}`` dict per FASTA record in ``lines``.

    ``lines`` is any iterable of text lines (e.g. an open text-mode file).
    Blank lines are skipped. A line starting with '>' starts a new record whose
    id is the header text without the '>'. The non-header lines that follow are
    stripped and concatenated into that record's sequence. Lines that appear
    before the first header belong to no record and are dropped.
    """
    current_id = None
    current_sequence = []
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # A new header closes the previous record (if any).
            if current_id is not None:
                yield {'id': current_id, 'sequence': ''.join(current_sequence)}
            current_id = line[1:]
            current_sequence = []
        else:
            current_sequence.append(line)
    # The last record has no following header to flush it.
    if current_id is not None:
        yield {'id': current_id, 'sequence': ''.join(current_sequence)}


def read_fasta_file(file_path, encoding=None):
    """Read a FASTA file and return its records.

    Parameters
    ----------
    file_path : str or path-like
        Path to the FASTA file. Paths ending in ``.gz`` are decompressed
        transparently (UniProt's downloadable FASTA files are usually gzipped).
    encoding : str or None, default None
        Text encoding passed to the file opener. ``None`` uses the platform
        default, which is the same as the previous behaviour.

    Returns
    -------
    list of dict
        One dict per record with keys ``'id'`` (the header line without the
        leading '>') and ``'sequence'`` (all sequence lines concatenated).
        If the file is missing, or any other exception occurs while reading,
        an error message is printed and ``[]`` is returned. Records parsed
        before the error are discarded.
    """
    import gzip  # local import keeps this cell self-contained

    opener = gzip.open if str(file_path).endswith('.gz') else open
    try:
        with opener(file_path, 'rt', encoding=encoding) as file:
            # Materialise inside the `with` so parsing finishes before the file closes.
            return list(_iter_fasta_records(file))

    except FileNotFoundError:
        print(f"错误: 文件 '{file_path}' 未找到")
        return []
    except Exception as e:
        # Deliberate best-effort: report the problem and return no records.
        print(f"错误: 读取文件时发生异常: {e}")
        return []

# Usage example: load every record from the UniProt FASTA file.
fasta_file = "uniprot_sprot.fasta"
sequences = read_fasta_file(fasta_file)
# read_fasta_file prints an error and returns [] on failure, so show the count
# to make a failed or empty load obvious before later cells use `sequences`.
print(f"Loaded {len(sequences)} sequences from {fasta_file}")



In [None]:
# Masked-residue prediction: score every vocabulary token at the <mask> position.
masked_sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYN<mask>VATPRGYVLAGG"
inputs = tokenizer(masked_sequence, return_tensors="pt")
with torch.no_grad():
    # Only the logits are used here. Requesting per-layer hidden states and
    # attention weights would only cost memory and time.
    outputs = model(**inputs)
logits = outputs.logits
# Column indices (token positions) of <mask> in the single-sequence batch.
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
if mask_token_index.numel() == 0:
    raise ValueError("masked_sequence contains no <mask> token")
# Highest-scoring token per masked position. NOTE(review): the argmax runs over
# the whole vocabulary, which also contains special tokens such as <mask>.
predicted_token_id = logits[0, mask_token_index].argmax(-1)
predicted_aa = tokenizer.decode(predicted_token_id)  # token id(s) -> residue symbol(s)
print(f"Predicted amino acid: {predicted_aa}")

In [None]:
def get_embedding(seq, pooling="cls"):
    """Embed one protein sequence using the notebook's ESM-2 ``tokenizer`` and ``model``.

    Parameters
    ----------
    seq : str
        Protein sequence to embed.
    pooling : {"cls", "mean"}, default "cls"
        ``"cls"``: return the final-layer vector at token position 0 (the
        [CLS]-style token). This is the original behaviour.
        ``"mean"``: average the final-layer vectors of all non-special tokens,
        i.e. the residues only.

    Returns
    -------
    torch.Tensor
        A 1-D vector taken from the last entry of ``hidden_states``.

    Raises
    ------
    ValueError
        If ``pooling`` is not recognised, or if ``pooling="mean"`` and the
        tokenized sequence has no non-special tokens to average.
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"pooling must be 'cls' or 'mean', got {pooling!r}")
    want_mean = pooling == "mean"
    inputs = tokenizer(seq, return_tensors="pt", return_special_tokens_mask=want_mean)
    special_tokens_mask = None
    if want_mean:
        # The model's forward() does not take this key, so remove it before the call.
        special_tokens_mask = inputs.pop("special_tokens_mask")[0].bool()
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
    last_layer = outputs.hidden_states[-1][0]  # drop the batch axis: (tokens, hidden)
    if not want_mean:
        return last_layer[0, :]  # [CLS]-style token embedding
    residue_vectors = last_layer[~special_tokens_mask]
    if residue_vectors.shape[0] == 0:
        raise ValueError("sequence has no residue tokens to average")
    return residue_vectors.mean(dim=0)

# seq1 / seq2 were never assigned anywhere in this notebook, so a fresh
# "Restart & Run All" stopped here with a NameError. Compare the first two
# records loaded from the FASTA file instead.
if len(sequences) < 2:
    raise ValueError("Need at least two FASTA records to compare; check that the FASTA file loaded.")
seq1 = sequences[0]["sequence"]
seq2 = sequences[1]["sequence"]

emb1 = get_embedding(seq1)
emb2 = get_embedding(seq2)
# get_embedding returns 1-D vectors, so the similarity is taken along dim 0.
similarity = torch.cosine_similarity(emb1, emb2, dim=0)
print(f"Embedding Similarity: {similarity.item()}")

In [None]:
# 3. Protein sequence embedding extraction (feature extraction)
# NOTE(review): this embeds `masked_sequence`, which still contains a <mask>
# token. Substitute an unmasked sequence if a real protein embedding is wanted.
inputs = tokenizer(masked_sequence, return_tensors="pt")
with torch.no_grad():
    # Hidden states are used below. Attention weights were requested previously
    # but never used, so they are no longer computed.
    outputs = model(**inputs, output_hidden_states=True)

last_hidden_states = outputs.hidden_states[-1]  # final layer's output, batch axis first
print(last_hidden_states.shape)
# Final-layer vector at token position 0 (the [CLS]-style token), for the first sequence in the batch.
cls_embedding = last_hidden_states[0, 0, :]
print(cls_embedding.shape)