In [1]:
!pip install biopython
!pip install fair-esm

import torch
import pandas as pd
import numpy as np
from Bio import SeqIO
import esm
import gc
# Compute device: use CUDA when torch reports it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('GPU available' if device.type == 'cuda' else 'CPU available')

Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/3.3 MB[0m [31m35.9 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m49.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85
Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-

In [2]:
# Load the pretrained ESM-1b model (esm1b_t33_650M_UR50S) together with its alphabet.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Converter used as: labels, strs, tokens = batch_converter([(label, sequence), ...])
batch_converter = alphabet.get_batch_converter()
# nn.Module.to moves the parameters in place, so no reassignment is needed.
model.to(device)
# Switch to evaluation mode; as the cell's last expression it also displays the architecture.
model.eval()

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt" to /root/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm1b_t33_650M_UR50S-contact-regression.pt


ProteinBertModel(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (embed_positions): LearnedPositionalEmbedding(1026, 1280, padding_idx=1)


In [3]:
def esm_to_feature(tuple_seq_list, num_layer, device, esm_model=None, esm_alphabet=None, verbose=True):
    """Embed each protein sequence as the mean of its per-residue ESM representations.

    Sequences are processed one at a time (batch size 1), so no padding is involved.

    Args:
        tuple_seq_list: list of ``(label, sequence)`` tuples, the format the ESM batch
            converter expects (see ``format``).
        num_layer: index of the layer whose representations are pooled
            (33 = final layer of ESM-1b t33).
        device: torch device (or device string) the tokens are moved to; the model is
            expected to already live on it.
        esm_model: model to run. Defaults to the notebook-global ``model``.
        esm_alphabet: alphabet matching ``esm_model``; its batch converter tokenizes the
            sequences. Defaults to the notebook-global ``alphabet``.
        verbose: print index, length and the full sequence for every entry (the original
            behaviour). Pass False to avoid flooding the notebook output.

    Returns:
        Tensor of shape ``(len(tuple_seq_list), embed_dim)`` on the model's output device.

    Raises:
        ValueError: if ``tuple_seq_list`` is empty.
    """
    if len(tuple_seq_list) == 0:
        raise ValueError("tuple_seq_list is empty; nothing to embed")
    # Fall back to the notebook globals so existing call sites keep working unchanged.
    if esm_model is None:
        esm_model = model
    if esm_alphabet is None:
        esm_alphabet = alphabet
    converter = esm_alphabet.get_batch_converter()

    total = len(tuple_seq_list)
    embeddings = []
    for i, tuple_seq in enumerate(tuple_seq_list, start=1):
        seq = tuple_seq[1]
        if verbose:
            print(f"{i}/{total}")
            print('长度: ', len(seq))
            print(seq)
        # The converter takes a *list* of (label, sequence) tuples, hence the wrapping list.
        _, _, batch_tokens = converter([tuple_seq])
        batch_tokens = batch_tokens.to(device)

        with torch.no_grad():
            results = esm_model(batch_tokens, repr_layers=[num_layer], return_contacts=False)
        token_representations = results["representations"][num_layer]

        # Non-padding token count = sequence length + 2 (start and end tokens).
        n_tokens = int((batch_tokens[0] != esm_alphabet.padding_idx).sum())
        # Token 0 is the start token, so residues occupy positions 1 .. n_tokens-2.
        embeddings.append(token_representations[0, 1:n_tokens - 1].mean(0))

        # Drop the large per-sequence tensors before releasing cached GPU / Python memory.
        del batch_tokens, results, token_representations
        torch.cuda.empty_cache()
        gc.collect()

    return torch.stack(embeddings)

In [4]:
def format(file_path, label=0, max_len=1022):
    """Read a FASTA file into the ``(label, sequence)`` tuples consumed by ``esm_to_feature``.

    NOTE: this name shadows the built-in ``format``; it is kept so existing calls still work.

    Args:
        file_path: path to a FASTA file.
        label: value appended to the returned label list once per record (the notebook
            uses the default 0 and discards the list).
        max_len: sequences longer than this are truncated to their first ``max_len``
            residues (1022 = ESM-1b's 1024-token limit minus the start/end tokens).
            ``None`` disables truncation.

    Returns:
        ``(sequences, labels)``: ``sequences`` is a list of ``(record_id, sequence)``
        tuples, ``labels`` a list of ``label`` with the same length.
    """
    sequences = []
    labels = []
    for record in SeqIO.parse(file_path, "fasta"):
        # Slicing is a no-op for short sequences and for max_len=None.
        seq = str(record.seq)[:max_len]
        # Use the FASTA record id as the tuple label. The previous code used `_`, which in
        # IPython is the last cell output, i.e. unrelated hidden kernel state.
        sequences.append((record.id, seq))
        labels.append(label)
    return sequences, labels

In [5]:
# Embed the training set with ESM-1b (final layer, mean-pooled) and save the features.
file_path = 'merged_train.fasta'
# The label list is unused. Don't assign it to `_`: IPython reserves `_` for the last output.
tuple_seq_list, _labels = format(file_path)
embeddings = esm_to_feature(tuple_seq_list, num_layer=33, device=device)
assert embeddings.shape[0] == len(tuple_seq_list), "one embedding per sequence expected"
# Save a CPU copy so the file loads on machines without a GPU (no map_location needed).
torch.save(embeddings.cpu(), 'merged_train_esm1b.pt')
print('done')

[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
MQSWSRVYCSLAKRGHFNRISHGLQGLSAVPLRTYADQPIDADVTVIGSGPGGYVAAIKAAQLGFKTVCIEKNETLGGTCLNVGCIPSKALLNNSHYYHMAHGKDFASRGIEMSEVRLNLDKMMEQKSTAVKALTGGIAHLFKQNKVVHVNGYGKITGKNQVTATKADGGTQVIDTKNILIATGSEVTPFPGITIDEDTIVSSTGALSLKKVPEKMVVIGAGVIGVELGSVWQRLGADVTAVEFLGHVGGVGIDMEISKNFQRILQKQGFKFKLNTKVTGATKKSDGKIDVSIEAASGGKAEVITCDVLLVCIGRRPFTKNLGLEELGIELDPRGRIPVNTRFQTKIPNIYAIGDVVAGPMLAHKAEDEGIICVEGMAGGAVHIDYNCVPSVIYTHPEVAWVGKSEEQLKEEGIEYKVGKFPFAANSRAKTNADTDGMVKILGQKSTDRVLGAHILGPGAGEMVNEAALALEYGASCEDIARVCHAHPTLSEAFREANLAASFGKSINF
398/2063
长度:  1022
MGDAEGEDEVQFLRTDDEVVLQCSATVLKEQLKLCLAAEGFGNRLCFLEPTSNAQNVPPDLAICCFVLEQSLSVRALQEMLANTVEAGVESSQGGGHRTLLYGHAILLRHAHSRMYLSCLTTSRSMTDKLAFDVGLQEDATGEACWWTMHPASKQRSEGEKVRVGDDIILVSVSSERYLHLSTASGELQVDASFMQTLWNMNPICSRCEEGFVTGGHVLRLFHGHMDECLTISPADSDDQRRLVYYEGGAVCTHARSLWRLEPLRISWSGSHLRWGQPLRVRHVTTGQYLALTEDQGLVVVDASKAHTKATSFCFRISKEKLDVAPKRDVEGMGPPEIKYGESLCFVQHVASGLWLTYAAPDPKALRLGVLKKKAMLHQEGHMDDALSLTRCQQEESQAARMIHSTNGLYNQFIKSLDSFSGKPRGSGPP

In [None]:
# Embed the test set with ESM-1b (final layer, mean-pooled) and save the features.
file_path = 'merged_test.fasta'
# The label list is unused. Don't assign it to `_`: IPython reserves `_` for the last output.
tuple_seq_list, _labels = format(file_path)
embeddings = esm_to_feature(tuple_seq_list, num_layer=33, device=device)
assert embeddings.shape[0] == len(tuple_seq_list), "one embedding per sequence expected"
# Save a CPU copy so the file loads on machines without a GPU (no map_location needed).
torch.save(embeddings.cpu(), 'merged_test_esm1b.pt')
print('done')

1/478
长度:  440
MNYSLHLAFVCLSLFTERMCIQGSQFNVEVGRSDKLSLPGFENLTAGYNKFLRPNFGGEPVQIALTLDIASISSISESNMDYTATIYLRQRWMDQRLVFEGNKSFTLDARLVEFLWVPDTYIVESKKSFLHEVTVGNRLIRLFSNGTVLYALRITTTVACNMDLSKYPMDTQTCKLQLESWGYDGNDVEFTWLRGNDSVRGLEHLRLAQYTIERYFTLVTRSQQETGNYTRLVLQFELRRNVLYFILETYVPSTFLVVLSWVSFWISLDSVPARTCIGVTTVLSMTTLMIGSRTSLPNTNCFIKAIDVYLGICFSFVFGALLEYAVAHYSSLQQMAAKDRGTTKEVEEVSITNIINSSISSFKRKISFASIEISSDNVDYSDLTMKTSDKFKFVFREKMGRIVDYFTIQNPSNVDHYSKLLFPLIFMLANVFYWAYYMYF
2/478
长度:  1022
MDAKARNCLLQHREALEKDIKTSYIMDHMISDGFLTISEEEKVRNEPTQQQRAAMLIKMILKKDNDSYVSFYNALLHEGYKDLAALLHDGIPVVSSSSGKDSVSGITSYVRTVLCEGGVPQRPVVFVTRKKLVNAIQQKLSKLKGEPGWVTIHGMAGCGKSVLAAEAVRDHSLLEGCFPGGVHWVSVGKQDKSGLLMKLQNLCTRLDQDESFSQRLPLNIEEAKDRLRILMLRKHPRSLLILDDVWDSWVLKAFDSQCQILLTTRDKSVTDSVMGPKYVVPVESSLGKEKGLEILSLFVNMKKADLPEQAHSIIKECKGSPLVVSLIGALLRDFPNRWEYYLKQLQNKQFKRIRKSSSYDYEALDEAMSISVEMLREDIKDYYTDLSILQKDVKVPTKVLCILWDMETEEVEDILQEFVNKSLLFCDRNGKSFRYYLHDLQVDFLTEKNCSQLQDLHKKIITQFQRYHQPHTLSPDQEDCMYWYNFLAYHMASAKMHKELCALMFSLDWIKAKTELVGPAHLIHEFVE