In [None]:
import json

import esm
import numpy as np
import pandas as pd
import torch


In [None]:
# Input data, column names, batching and compute device.
df = pd.read_csv("Stratified_data.csv")
id_column = "ID"
seq_column = "Sequence"
# Fail fast with a clear message instead of a KeyError in a later cell.
if seq_column not in df.columns:
    raise KeyError(
        f"Column {seq_column!r} not found in Stratified_data.csv; got {list(df.columns)}"
    )
# Row position (read_csv's default RangeIndex) serves as a per-sequence identifier.
df[id_column] = df.index
batch_size = 32  # sequences per forward pass; lower this if memory runs out
# Use a GPU when one is available; inference with a 650M-parameter model is slow on CPU.
# Replace with torch.device("cpu") to force CPU execution.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build (label, sequence) pairs in the format expected by ESM's batch converter.
# Missing values must be rejected explicitly: str(NaN) would otherwise be
# embedded as the three-residue "sequence" 'nan'.
missing = df[seq_column].isna()
if missing.any():
    raise ValueError(
        f"{int(missing.sum())} row(s) have no value in column {seq_column!r}"
    )

# Normalise each sequence: drop all whitespace and upper-case it. ESM's alphabet
# tokenises upper-case residue letters, so lower-case text or stray whitespace
# would not map one-character-per-token and would distort the pooled embedding.
sequences = [
    (str(name), "".join(str(seq).split()).upper())
    for name, seq in zip(df[id_column], df[seq_column])
]

# An empty sequence has no residue tokens and would mean-pool to NaN.
empty_ids = [name for name, seq in sequences if not seq]
if empty_ids:
    raise ValueError(f"Empty sequence(s) for {id_column} = {empty_ids[:10]}")

In [None]:
# Load the pretrained ESM-2 checkpoint (esm2_t33_650M_UR50D) and its alphabet,
# move the weights to the configured device and switch to inference mode.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = model.to(device)
model.eval()  # disables dropout so embeddings are deterministic

# Callable that turns [(label, sequence), ...] into (labels, strs, token tensor).
batch_converter = alphabet.get_batch_converter()


In [None]:
# Embed every sequence: run the model batch by batch and mean-pool the
# per-residue representations of one hidden layer into a fixed-length vector.
# NOTE(review): the model name suggests 33 layers, so layer 31 is not the final
# layer — confirm this choice is intentional.
REPR_LAYER = 31

all_embeddings = []
n_batches = (len(sequences) + batch_size - 1) // batch_size
# Offsets of the special tokens the batch converter adds around each sequence.
bos_offset = int(alphabet.prepend_bos)
eos_offset = int(alphabet.append_eos)

for batch_idx, start in enumerate(range(0, len(sequences), batch_size), start=1):
    batch = sequences[start:start + batch_size]
    _, _, tokens = batch_converter(batch)
    tokens = tokens.to(device)

    with torch.no_grad():
        results = model(tokens, repr_layers=[REPR_LAYER])
    token_reps = results["representations"][REPR_LAYER]

    # Derive the pooling window from the actual tokens rather than len(seq):
    # if tokenisation does not yield one token per character, a len(seq)-based
    # slice would pull EOS/padding positions into the average.
    n_real_tokens = (tokens != alphabet.padding_idx).sum(dim=1).tolist()
    for j, n_tok in enumerate(n_real_tokens):
        residue_reps = token_reps[j, bos_offset:n_tok - eos_offset]
        all_embeddings.append(residue_reps.mean(0).cpu().numpy())

    print(f"Processed batch {batch_idx}/{n_batches}")

In [None]:
# Attach the embeddings to the table and write it to disk.
all_embeddings = np.stack(all_embeddings)  # one row per input sequence
assert len(all_embeddings) == len(df), (
    f"{len(all_embeddings)} embeddings for {len(df)} rows"
)
# In memory, keep each embedding as a numpy array.
df["esm2_embedding"] = list(all_embeddings)

# On disk, write each embedding as a complete JSON list. Passing the arrays
# straight to to_csv stores their str() repr, and numpy abbreviates arrays with
# more than 1000 elements using "...", silently losing most of the values.
# Read back with: pd.read_csv(path)["esm2_embedding"].map(json.loads)
df.assign(
    esm2_embedding=[json.dumps(emb.tolist()) for emb in all_embeddings]
).to_csv("sequences_with_embeddings.csv", index=False)
