In [None]:
!pip install fair-esm
import torch
import esm
import pandas as pd
from tqdm import tqdm
from google.colab import files

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [None]:
# Load the esm2_t30_150M_UR50D checkpoint (30 layers, ~150M parameters): a smaller
# ESM-2 model chosen so that inference stays practical on CPU.
# NOTE: `model`, `alphabet` and `batch_converter` are read as notebook globals by
# later cells, so this cell must run before any embedding is computed.
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
# The batch converter maps a list of (label, sequence) pairs to
# (labels, sequence strings, token tensor).
batch_converter = alphabet.get_batch_converter()
# Put the model in evaluation mode for inference. No `.to(device)` call is made, so
# it stays on the CPU. As the cell's last expression this also displays the model.
model.eval()  # Keep on CPU

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t30_150M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 640, padding_idx=1)
  (layers): ModuleList(
    (0-29): 30 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=2560, bias=True)
      (fc2): Linear(in_features=2560, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=600, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((640,), eps=1e-05, elementw

In [None]:
# Get the number of layers in the model.
# `num_layers` is used later as the `repr_layers` index, i.e. embeddings are taken
# from the final transformer layer. It is read as a notebook global by
# read_and_process_in_batches, so this cell must run before that function is called.
num_layers = len(model.layers)
print(f"Model has {num_layers} layers, using layer {num_layers} for embeddings")

Model has 30 layers, using layer 30 for embeddings


In [None]:
def _read_fasta_records(file_path):
    """Parse a FASTA file into a list of (header, sequence) tuples.

    The header is the full stripped header line, including the leading ">".
    Sequence lines are stripped and concatenated. Lines that appear before the
    first header are ignored.
    """
    records = []
    header = ""
    chunks = []
    with open(file_path, 'r') as handle:
        for line in handle:
            if line.startswith(">"):
                if header:
                    records.append((header, "".join(chunks)))
                header = line.strip()
                chunks = []
            else:
                # Collect pieces and join once instead of repeated string concatenation.
                chunks.append(line.strip())
    if header:
        records.append((header, "".join(chunks)))
    return records


def _embed_batch(batch, device):
    """Return the position-0 embeddings of a batch as a 2-D NumPy array (one row per sequence)."""
    _, _, batch_tokens = batch_converter(batch)
    # Tokens must live on the same device as the model weights.
    batch_tokens = batch_tokens.to(device)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[num_layers], return_contacts=False)
        token_embeddings = results["representations"][num_layers]

    # .cpu() is a no-op on CPU tensors and makes .numpy() valid when the model is on GPU.
    return token_embeddings[:, 0, :].cpu().numpy()


def read_and_process_in_batches(file_path, batch_size=2):  # Small batch size for CPU
    """
    Read a FASTA file and compute one embedding per sequence with the loaded ESM model.

    The function relies on the notebook globals `model`, `batch_converter` and
    `num_layers`. Inference runs on whatever device the model's parameters are on
    (CPU by default, GPU if the model was moved there), and results are always
    returned as CPU NumPy arrays.

    Workflow:
    1. Read all (header, sequence) records from the FASTA file.
    2. Split the records into batches of `batch_size`.
    3. For each batch, run the model and keep the final-layer embedding at token
       position 0 (the `<cls>` / beginning-of-sequence token the ESM alphabet prepends).
    4. If a batch fails, report the error and retry its sequences one at a time so
       that only the sequences that actually fail are skipped.

    Parameters:
    - file_path (str): Path to the input FASTA file.
    - batch_size (int): Number of sequences per batch (default: 2 for CPU compatibility).
      Must be at least 1.

    Returns:
    - seq_ids (list of str): FASTA header lines (including the leading ">") for the
      sequences that were embedded, in file order.
    - cls_embeddings (list of np.ndarray): One 1-D embedding per entry in `seq_ids`.

    Raises:
    - ValueError: If `batch_size` is smaller than 1.
    - OSError: If the FASTA file cannot be opened.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    sequences = _read_fasta_records(file_path)

    # Follow the model's device rather than assuming CPU.
    device = next(model.parameters()).device

    cls_embeddings = []
    seq_ids = []

    for i in tqdm(range(0, len(sequences), batch_size)):
        batch = sequences[i:i+batch_size]
        try:
            batch_embeddings = _embed_batch(batch, device)
        except Exception as e:
            print(f"Error processing batch {i//batch_size}: {str(e)}")
            if len(batch) == 1:
                continue
            # Retry one sequence at a time so a single bad (or too large) sequence
            # does not discard the rest of the batch.
            for record in batch:
                try:
                    single_embedding = _embed_batch([record], device)
                except Exception as single_error:
                    print(f"Skipping sequence {record[0]}: {str(single_error)}")
                    continue
                cls_embeddings.extend(single_embedding)
                seq_ids.append(record[0])
            continue

        cls_embeddings.extend(batch_embeddings)
        seq_ids.extend(header for header, _ in batch)

    return seq_ids, cls_embeddings


In [None]:
# =========================================================
# Example: Compute ESM-2 CLS embeddings for cleaned DMC1 of arthropods
#
# 1️⃣ Upload the .faa file containing cleaned DMC1 protein sequences:
#    (Example filename: "cleaned_DMC1_A.faa")
#     from google.colab import files
#     uploaded = files.upload()
#
# 2️⃣ Process sequences in batches and compute embeddings:
#     seq_ids, cls_embeddings = read_and_process_in_batches("cleaned_DMC1_A.faa")
#
# 3️⃣ Convert results into a DataFrame and save as CSV:
#     df = pd.DataFrame(cls_embeddings, index=seq_ids)
#     df.to_csv("cleaned_DMC1_A_CLS_embeddings.csv")
#
# 4️⃣ Download the final CLS embeddings CSV file:
#     files.download("cleaned_DMC1_A_CLS_embeddings.csv")
# =========================================================

Saving cleaned_DMC1_A.faa to cleaned_DMC1_A.faa


100%|██████████| 61/61 [05:05<00:00,  5.01s/it]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>