In [1]:
!pip install fair-esm
import torch
import esm
import pandas as pd
from tqdm import tqdm
from google.colab import files

Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0


In [2]:
# Load the 30-layer, 150M-parameter ESM-2 checkpoint (esm2_t30_150M_UR50D), a
# smaller variant chosen because it works better on CPU. The loader returns the
# model together with its token alphabet.
model, alphabet = esm.pretrained.esm2_t30_150M_UR50D()
# The batch converter turns a list of (label, sequence) pairs into
# (labels, sequence strings, token tensor), as unpacked in process_batch.
batch_converter = alphabet.get_batch_converter()
# Put the model in evaluation mode. It is never moved to another device here.
# Being the last expression of the cell, this call also displays the model.
model.eval()  # Keep on CPU

Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t30_150M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t30_150M_UR50D-contact-regression.pt


ESM2(
  (embed_tokens): Embedding(33, 640, padding_idx=1)
  (layers): ModuleList(
    (0-29): 30 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=640, out_features=640, bias=True)
        (v_proj): Linear(in_features=640, out_features=640, bias=True)
        (q_proj): Linear(in_features=640, out_features=640, bias=True)
        (out_proj): Linear(in_features=640, out_features=640, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=640, out_features=2560, bias=True)
      (fc2): Linear(in_features=2560, out_features=640, bias=True)
      (final_layer_norm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=600, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((640,), eps=1e-05, elementw

In [3]:
# Get the number of layers in the model. This value is used by process_batch
# both as the `repr_layers` request and as the key into the returned
# representations, i.e. embeddings are taken from the layer with this index.
num_layers = len(model.layers)
print(f"Model has {num_layers} layers, using layer {num_layers} for embeddings")

Model has 30 layers, using layer 30 for embeddings


In [12]:
from multiprocessing import Pool, cpu_count
import torch
from tqdm import tqdm

def process_batch(batch):
    """
    Run one batch of protein sequences through the ESM model and return the
    embedding found at token position 0 (the [CLS] slot) for each sequence.

    Relies on the module-level `batch_converter`, `model` and `num_layers`.

    Args:
        batch (list of tuples): (sequence_id, sequence_string) pairs.

    Returns:
        tuple: (list of sequence_ids, array of CLS embeddings with one row per
               sequence), or None if anything in the batch raised. In that
               case the error message is printed.
    """
    try:
        # Only the token tensor is needed; labels and raw strings are discarded.
        _, _, tokens = batch_converter(batch)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            output = model(tokens, repr_layers=[num_layers], return_contacts=False)

        layer_repr = output["representations"][num_layers]

        identifiers = [identifier for identifier, _ in batch]
        # Position 0 along the token axis holds the [CLS] representation.
        cls_vectors = layer_repr[:, 0, :].numpy()
        return (identifiers, cls_vectors)
    except Exception as e:
        # Best effort: report the failure and let the caller skip this batch.
        print(f"Error processing batch: {str(e)}")
        return None

def _read_fasta(file_path):
    """Parse a FASTA file into a list of (header, sequence) tuples.

    Headers are the stripped header lines and keep their leading ">".
    Sequence lines are stripped and concatenated. Any text before the first
    header line is ignored.
    """
    records = []
    header = ""
    chunks = []
    with open(file_path, 'r') as handle:
        for line in handle:
            if line.startswith(">"):
                # A new header closes the record collected so far, if any.
                if header:
                    records.append((header, "".join(chunks)))
                header = line.strip()
                chunks = []
            else:
                chunks.append(line.strip())
    # Flush the final record, which has no following header to close it.
    if header:
        records.append((header, "".join(chunks)))
    return records


def read_and_process_in_batches(file_path, batch_size=2, num_processes=None):
    """
    Read protein sequences from a FASTA file, batch them, and extract CLS embeddings using ESM.

    Batches are dispatched to `process_batch` through a multiprocessing pool.
    Batches for which `process_batch` returns None (i.e. that failed) are
    skipped, so the output may contain fewer entries than the input file.

    Args:
        file_path (str): Path to input FASTA file.
        batch_size (int, optional): Number of sequences per batch (default=2).
                                    Must be at least 1.
        num_processes (int, optional): Number of worker processes.
                                       If None, uses one per batch, capped at
                                       the number of CPU cores.

    Returns:
        tuple: (list of sequence IDs, list of CLS embeddings)
               - sequence IDs: list of strings (FASTA header lines, including
                 the leading ">"), one per embedded sequence.
               - CLS embeddings: list of numpy arrays, one per sequence ID.

    Raises:
        ValueError: If batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    sequences = _read_fasta(file_path)

    # Nothing to embed: avoid starting worker processes for an empty input.
    if not sequences:
        return [], []

    # Prepare batches
    batches = [sequences[i:i + batch_size] for i in range(0, len(sequences), batch_size)]

    # Determine number of processes to use. Counting the batches themselves
    # (rather than len(sequences) // batch_size) also accounts for a trailing
    # partial batch.
    if num_processes is None:
        num_processes = max(1, min(cpu_count(), len(batches)))

    # Process batches in parallel; imap preserves the input order of batches.
    with Pool(num_processes) as pool:
        results = list(tqdm(pool.imap(process_batch, batches), total=len(batches)))

    # Collect results, dropping batches that failed in the worker.
    seq_ids = []
    cls_embeddings = []
    for result in results:
        if result is not None:
            batch_seq_ids, batch_embeddings = result
            seq_ids.extend(batch_seq_ids)
            cls_embeddings.extend(batch_embeddings)

    return seq_ids, cls_embeddings

In [22]:
import os

from google.colab import files
import pandas as pd

# Upload the file. The returned mapping is keyed by uploaded file name; only
# the first uploaded file is processed.
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; re-run this cell and select a FASTA file.")
file_name = next(iter(uploaded))

# Process sequences
seq_ids, cls_embeddings = read_and_process_in_batches(file_name, batch_size=2)

# Create DataFrame (one row per sequence, indexed by its ID) and save.
df = pd.DataFrame(cls_embeddings, index=seq_ids)
# Derive the output name from the file stem so that ANY input extension
# (.faa, .fasta, .fa, .txt, none) maps to a distinct CSV path and the uploaded
# FASTA file can never be overwritten by the results.
output_filename = os.path.splitext(file_name)[0] + '_CLS_embeddings.csv'
df.to_csv(output_filename)

# Download the results
files.download(output_filename)

Saving cleaned_MYO_animals_top3.faa to cleaned_MYO_animals_top3.faa


100%|██████████| 41/41 [44:27<00:00, 65.07s/it]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>