In [None]:
!pip install esm
!pip install biopython

In [None]:
# Import libraries
# Standard libraries
import pandas as pd
import numpy as np

from Bio import SeqIO
import matplotlib.pyplot as plt

# ML libraries
import torch
from huggingface_hub import login

# ESMC and batching libraries
import esm
from esm.sdk.api import (
    ESM3InferenceClient, 
    ESMProtein, 
    GenerationConfig, 
    ESMProteinError, 
    LogitsConfig, 
    LogitsOutput, 
    ProteinType
)
from concurrent.futures import ThreadPoolExecutor
from typing import Sequence

In [None]:
## Load ESM-C models from forge
# getpass prompts without echoing input, so the Forge API token is never
# written into the notebook source or its saved outputs.
from getpass import getpass
token = getpass("Token from Forge console: ")

from esm.sdk import client

FORGE_URL = "https://forge.evolutionaryscale.ai"

# Remote client for the "esmc-6b-2024-12" model hosted on Forge; every later
# encode/logits call goes over the network through this object.
model: ESM3InferenceClient = client(
    model="esmc-6b-2024-12",
    url=FORGE_URL,
    token=token,
)

In [None]:
## Read all sequences as fasta files
def read_sequences(
    fasta_path: str) -> pd.DataFrame:
    """Load every record of a FASTA file into a two-column DataFrame.

    Parameters
    ----------
    fasta_path : str
        Path to a FASTA file, parsed with ``Bio.SeqIO.parse(fasta_path, "fasta")``.

    Returns
    -------
    pd.DataFrame
        One row per record, in file order, with a fresh 0..n-1 index and columns:

        - ``description``: the record's ``record.id`` (Biopython's identifier
          field, not ``record.description``; the column name is historical).
        - ``sequence``: the residues as a plain ``str``.

        An empty file yields an empty frame with the same two columns.
    """
    # Build all rows first and create the frame once. The previous per-record
    # pd.concat copied the whole frame on every iteration (quadratic time), and
    # concatenating onto an empty frame triggers a pandas FutureWarning.
    rows = [
        (record.id, str(record.seq))
        for record in SeqIO.parse(fasta_path, "fasta")
    ]
    return pd.DataFrame(rows, columns=["description", "sequence"])

In [None]:
## Input sequences and output logits and embeddings
def embed_sequence(
    model: ESM3InferenceClient, sequence: str) -> LogitsOutput:
    """Encode one protein sequence and fetch its logits and embeddings.

    The sequence is wrapped in an ``ESMProtein`` and encoded with
    ``model.encode``. The encoded result goes to ``model.logits`` with a
    config that asks for sequence logits and embeddings but not hidden states.

    Args:
        model: Client exposing ``encode`` and ``logits`` (the Forge client
            created earlier in this notebook).
        sequence: Protein sequence string.

    Returns:
        The object returned by ``model.logits`` (annotated as ``LogitsOutput``).
        Any exception raised by the SDK propagates to the caller.
    """
    encoded = model.encode(ESMProtein(sequence=sequence))
    request = LogitsConfig(
        sequence=True,
        return_embeddings=True,
        return_hidden_states=False,
    )
    return model.logits(encoded, request)

## Batch embed sequences
def batch_embed(
    model: ESM3InferenceClient, inputs: Sequence[ProteinType]) -> Sequence[LogitsOutput]:
    """Run ``embed_sequence`` on every input concurrently.

    Each input is submitted to a thread pool. Results are then collected in
    submission order, so ``result[i]`` corresponds to ``inputs[i]``. If a call
    raises, its slot holds ``ESMProteinError(500, str(exc))`` instead, so the
    output length always equals the input length.
    """
    outcomes = []
    with ThreadPoolExecutor() as pool:
        pending = [pool.submit(embed_sequence, model, item) for item in inputs]
        for job in pending:
            try:
                outcome = job.result()
            except Exception as exc:
                outcome = ESMProteinError(500, str(exc))
            outcomes.append(outcome)
    return outcomes

In [None]:
## Input sequences and output encoded sequence
def encode_sequence(
    model: ESM3InferenceClient, sequence: str) -> torch.Tensor:
    """Wrap ``sequence`` in an ``ESMProtein`` and return ``model.encode`` of it.

    Exceptions from the SDK propagate to the caller.

    NOTE(review): the return annotation says ``torch.Tensor``, but the value is
    whatever the SDK's ``encode`` returns (presumably a protein-tensor wrapper
    object rather than a bare tensor). Confirm before relying on tensor methods.
    """
    return model.encode(ESMProtein(sequence=sequence))

## Batch encode sequences
def batch_encode(
    model: ESM3InferenceClient, inputs: Sequence[ProteinType]) -> Sequence[torch.Tensor]:
    """Run ``encode_sequence`` on every input concurrently.

    Each input is submitted to a thread pool. Results are collected in
    submission order, so ``result[i]`` always corresponds to ``inputs[i]``.

    A call that raises is recorded in place as ``ESMProteinError(500, str(e))``.
    This is the same convention ``batch_embed`` uses. Previously the error was
    printed and the item dropped, which made the output shorter than the input.
    Every later result then silently lined up with the wrong input row.

    Note that the element annotation says ``torch.Tensor``, but failed slots
    hold ``ESMProteinError`` objects. Check each element before using it.
    """
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(encode_sequence, model, protein) for protein in inputs
        ]
        results = []
        for future in futures:
            try:
                results.append(future.result())
            except Exception as e:
                # Keep a placeholder so positions stay aligned with `inputs`.
                results.append(ESMProteinError(500, str(e)))
    return results

In [None]:
## Run the program!
# Tunable parameters for this run.
N_SAMPLE = 500        # size of the random subset passed to batch_encode
SAMPLE_SEED = 31718   # random_state for the subset, for reproducibility
MAX_LEN = 400         # sequences strictly shorter than this go to batch_embed

# Import sequences and filter
fasta_path = '/home/azureuser/cloudfiles/code/Users/jc62/projects/direct_sequence_analysis/data/antiphage_sequences.fasta'
fasta_df = read_sequences(fasta_path)

# Filter sequences for processing.
# Cap the sample at the number of available records. DataFrame.sample (without
# replacement) raises ValueError when asked for more rows than exist. With at
# least N_SAMPLE records, the result is identical to sample(N_SAMPLE).
fasta_short = fasta_df.sample(min(N_SAMPLE, len(fasta_df)), random_state=SAMPLE_SEED)
fasta_really_short = fasta_df[fasta_df['sequence'].str.len() < MAX_LEN]

# Summarize both subsets. Each one is sent to a different call below.
print("Random sample (sent to batch_encode):")
print(fasta_short['sequence'].str.len().describe())
print(f"Sequences shorter than {MAX_LEN} (sent to batch_embed):")
print(fasta_really_short['sequence'].str.len().describe())

# Carry out pLLM operation.
# Both calls below run. Each sequence triggers model.encode (plus model.logits
# for batch_embed) against the Forge client, so comment out the one you don't need.
outputs = batch_embed(model, fasta_really_short["sequence"].tolist())
# or
output_tensors = batch_encode(model, fasta_short["sequence"].tolist())

In [None]:
## Save the output
# `outputs` is the list returned by batch_embed in the run cell. Each element is
# either the SDK's logits result or an ESMProteinError placeholder. torch.save
# pickles these Python objects, so loading the file back presumably requires
# torch.load(..., weights_only=False) on recent torch versions (confirm).
# NOTE(review): `outputs` was computed from `fasta_really_short` (length-filtered),
# not the 500-record random sample the filename suggests. Confirm which is intended.
embed_out_path = '/home/azureuser/cloudfiles/code/Users/jc62/projects/direct_sequence_analysis/data/esmc_embed_batch_500.pt'
torch.save(outputs, embed_out_path)