## This notebook reads the sequences in a FASTA file and saves them to a dataframe together with their embeddings (stored in embedding_df.csv)

In [8]:
import pandas as pd
import numpy as np
from ete3 import Tree
import pandas as pd
import os
from transformers import AutoModel, AutoTokenizer
import torch
import random
from Bio import SeqIO, AlignIO
from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq

In [9]:
def _fasta_record_fields(record):
    """Split one parsed record into its identifier-derived fields and sequences.

    Returns a tuple ``(fields, raw_sequence)`` where ``fields`` is the list
    ``[info, truncated_info, extracted_id, extracted_name, sequence]`` and
    ``raw_sequence`` is the record's sequence exactly as stored in the file
    (gap characters included).
    """
    record_id = record.id
    first_token = record_id.split(" ")[0]
    pipe_parts = record_id.split("|")

    # Second "|"-separated field when there is one, else the first token.
    extracted_id = pipe_parts[1] if len(pipe_parts) > 1 else first_token

    raw_sequence = str(record.seq)
    # An empty record yields None so that an all-empty column can be dropped.
    ungapped = raw_sequence.replace("-", "") if len(raw_sequence) > 0 else None

    return (
        [record_id, first_token, extracted_id, pipe_parts[-1], ungapped],
        raw_sequence,
    )


def get_sequence_df(
    *fasta_paths,
    drop_duplicates=True,
    alignment=False,
    ancestor=False,
    alphabet="ABCDEFGHIJKLMNOPQRSTUVWXYZ-",
):
    """Load one or more FASTA files into a single DataFrame.

    Args:
        *fasta_paths: Paths of the FASTA files to read.
        drop_duplicates: If True, keep only the first row for each ``info``
            value.
        alignment: If True, parse each file with ``AlignIO`` and record the
            path under ``original_alignment`` and the gapped sequence under
            ``Sequence_aligned``.
        ancestor: If True, the two alignment columns are added as well. When
            ``alignment`` is False the file is parsed with ``SeqIO``; the path
            is recorded under ``original_fasta`` and the sequence as stored in
            the file under ``Sequence_aligned``.
        alphabet: Accepted for backward compatibility; not used.

    Returns:
        A DataFrame with the columns ``info``, ``truncated_info``,
        ``extracted_id``, ``extracted_name``, ``sequence``, ``original_fasta``
        and, for alignments/ancestors, ``original_alignment`` and
        ``Sequence_aligned``. Columns that contain no values at all are
        dropped.
    """
    cols = [
        "info",
        "truncated_info",
        "extracted_id",
        "extracted_name",
        "sequence",
        "original_fasta",
    ]

    if alignment or ancestor:
        print("This is an alignment")
        cols.append("original_alignment")
        cols.append("Sequence_aligned")

    seq_list = []
    first_seen_in = {}  # record name -> first file it was found in

    for fasta_path in fasta_paths:
        # The parsers are lazy, so all iteration must happen while the handle
        # is open; the context manager guarantees it is closed afterwards.
        with open(fasta_path) as handle:
            if alignment:
                for msa in AlignIO.parse(handle, format="fasta"):
                    for aligned_seq in msa:
                        fields, raw_sequence = _fasta_record_fields(aligned_seq)
                        # original_fasta stays empty for alignment input.
                        seq_list.append(fields + [None, fasta_path, raw_sequence])
            else:
                for seq in SeqIO.parse(handle, format="fasta"):
                    if seq.name in first_seen_in:
                        print(
                            f"DUPLICATE:{seq.name} is in {first_seen_in[seq.name]} and {fasta_path}\n"
                        )
                    else:
                        first_seen_in[seq.name] = fasta_path

                    fields, raw_sequence = _fasta_record_fields(seq)
                    row = fields + [fasta_path]
                    if ancestor:
                        # Keep the row width in step with the two extra
                        # columns added above; without this the DataFrame
                        # constructor rejects the six-field rows.
                        row += [None, raw_sequence]
                    seq_list.append(row)

    df = pd.DataFrame(seq_list, columns=cols)

    if drop_duplicates:
        df = df.drop_duplicates(subset="info", keep="first")

    # Drop columns that hold no values at all (e.g. the sequence column when
    # only identifiers were supplied).
    df = df.dropna(how="all", axis=1)

    return df

In [10]:


def calculate_embeddings(sequence, model, tokenizer, model_type):
    """Calculate several pooled embeddings for a single sequence.

    The sequence is tokenized with its characters separated by spaces, run
    through ``model`` without gradient tracking, and the final hidden states
    are pooled over the token axis in four ways.

    Args:
        sequence: Iterable of single-character residues (typically a string).
        model: Callable model whose output exposes ``last_hidden_state``.
        tokenizer: Callable returning a mapping that contains ``input_ids``.
        model_type: ``"protbert"`` or ``"t5"``; also used as the key prefix.

    Returns:
        A dict with the keys ``{model_type}_mean_embedding``,
        ``{model_type}_cls_embedding``, ``{model_type}_max_embedding`` and
        ``{model_type}_weighted_embedding``, each holding a NumPy array.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    if model_type not in ("protbert", "t5"):
        raise ValueError(f"Unsupported model type: {model_type}")

    inputs = tokenizer(
        " ".join(sequence), return_tensors="pt", padding=True, truncation=True
    )

    with torch.no_grad():
        if model_type == "protbert":
            outputs = model(**inputs)
        else:
            # The ids are a tensor, not a mapping, so they must be passed by
            # keyword rather than unpacked with **. A full encoder-decoder
            # model would additionally demand decoder inputs, so use its
            # encoder stack when the model exposes one.
            encoder = model.get_encoder() if hasattr(model, "get_encoder") else model
            outputs = encoder(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )

    embeddings = outputs.last_hidden_state

    def _as_vector(tensor):
        # .cpu() is a no-op on CPU tensors and makes .numpy() valid otherwise.
        return tensor.squeeze().cpu().numpy()

    # Mean pooling over all token positions
    mean_embedding = _as_vector(embeddings.mean(dim=1))

    # First-token ("CLS") pooling
    cls_embedding = _as_vector(embeddings[:, 0])

    # Max pooling over all token positions
    max_embedding = _as_vector(embeddings.max(dim=1).values)

    # Position-weighted pooling: weights rise linearly from 0.1 to 1.0 along
    # the sequence and are broadcast over the batch and hidden dimensions.
    weights = torch.linspace(0.1, 1.0, embeddings.size(1), device=embeddings.device)
    weights = weights.unsqueeze(0).unsqueeze(-1)
    weighted_embedding = _as_vector((embeddings * weights).mean(dim=1))

    return {
        f"{model_type}_mean_embedding": mean_embedding,
        f"{model_type}_cls_embedding": cls_embedding,
        f"{model_type}_max_embedding": max_embedding,
        f"{model_type}_weighted_embedding": weighted_embedding,
    }


def process_and_store_embeddings(df, model_name, embedding_df_path, model_type):
    """Compute, cache and attach embeddings for every sequence in ``df``.

    Embeddings are kept in a pickled DataFrame at ``embedding_df_path``.
    Sequences whose ``info`` already has a non-missing
    ``{model_type}_mean_embedding`` for ``model_name`` in that store are
    skipped; all others are embedded with ``calculate_embeddings``.

    Args:
        df: DataFrame with at least the columns ``info`` and ``sequence``.
        model_name: Checkpoint name passed to ``AutoModel`` / ``AutoTokenizer``.
        embedding_df_path: Path of the pickle used as the embedding store.
        model_type: Model family, forwarded to ``calculate_embeddings``.

    Returns:
        ``df`` left-merged with the embedding store on ``info`` and
        ``sequence``.
    """
    model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Load existing embeddings if they exist.
    # NOTE: unpickling can execute arbitrary code; only use a trusted file.
    if os.path.exists(embedding_df_path):
        embedding_df = pd.read_pickle(embedding_df_path)
    else:
        embedding_df = pd.DataFrame(columns=["info", "sequence", "model_name"])

    # Work out once which identifiers are already embedded for this model.
    # A row only counts when its embedding value is actually present.
    marker_col = f"{model_type}_mean_embedding"
    if marker_col in embedding_df.columns:
        cached = (embedding_df["model_name"] == model_name) & embedding_df[
            marker_col
        ].notna()
        done = set(embedding_df.loc[cached, "info"])
    else:
        done = set()

    new_rows = []
    try:
        for info, sequence in zip(df["info"], df["sequence"]):
            if info in done:
                continue  # Skip if embeddings for this sequence already exist

            try:
                embeddings = calculate_embeddings(
                    sequence, model, tokenizer, model_type
                )
            except Exception as e:
                # Best effort: report the failure and move on to the next row.
                print(f"Failed to process sequence {sequence} with error: {e}")
                continue

            new_rows.append(
                {
                    "info": info,
                    "sequence": sequence,
                    "model_name": model_name,
                    **embeddings,
                }
            )
            done.add(info)
    finally:
        # Persist whatever was computed even when the loop is interrupted, so
        # a long run does not have to start again from scratch.
        if new_rows:
            embedding_df = pd.concat(
                [embedding_df, pd.DataFrame(new_rows)], ignore_index=True
            )
        embedding_df.to_pickle(embedding_df_path)

    # NOTE(review): the store may hold rows for several model names; the merge
    # is not restricted to `model_name`, so such rows all appear in the result.
    merged_df = pd.merge(df, embedding_df, on=["info", "sequence"], how="left")

    return merged_df


In [12]:
# Alignment files: reconstructed ancestors and the extant sequences.
ancestor_alignment_path = "./NR_MSA_ancestors.fa"
extant_alignment_path = "../NR_MSA.fasta"

# Both files are parsed as alignments, so the gapped sequences are kept too.
df = get_sequence_df(ancestor_alignment_path, alignment=True)
df_extant = get_sequence_df(extant_alignment_path, alignment=True)



This is an alignment
This is an alignment


In [13]:
# Preview the last rows of the extant-sequence table
# (use df.tail() instead to preview the ancestor table).
df_extant.tail()

Unnamed: 0,info,truncated_info,extracted_id,extracted_name,sequence,original_alignment,Sequence_aligned
16240,Podarcis_ESRRA_40-2641-225,Podarcis_ESRRA_40-2641-225,Podarcis_ESRRA_40-2641-225,Podarcis_ESRRA_40-2641-225,NTMVSHLLVAEPEKLYAMPDPALPDSPAKAASTLCDLADREIVVII...,../NR_MSA.fasta,----------------------------------------------...
16241,Petromyzon_ESRRB_82-3051-224,Petromyzon_ESRRB_82-3051-224,Petromyzon_ESRRB_82-3051-224,Petromyzon_ESRRB_82-3051-224,NKMVSQLLVVEPDRLFAMAGPGAAECDVTALTTLCDLADRELVLII...,../NR_MSA.fasta,----------------------------------------------...
16242,Ochotona_ESRRB_111-3341-224,Ochotona_ESRRB_111-3341-224,Ochotona_ESRRB_111-3341-224,Ochotona_ESRRB_111-3341-224,TKIVSCLMVAEPNNLQAMPPAGIPEADIKALATLCDLADRELVVII...,../NR_MSA.fasta,----------------------------------------------...
16243,Rana_ESRRB_130-3531-224,Rana_ESRRB_130-3531-224,Rana_ESRRB_130-3531-224,Rana_ESRRB_130-3531-224,TRIVSHLLLAEPEKIFAMADPAGPDSDIKVLSTLVDLTDRELVMTI...,../NR_MSA.fasta,----------------------------------------------...
16244,Pipistrellus_ESRRB_147-371-224,Pipistrellus_ESRRB_147-371-224,Pipistrellus_ESRRB_147-371-224,Pipistrellus_ESRRB_147-371-224,TKIVSYLLVAEPNKPSARPPPGMPESDIKALTTLCDLADQELVSII...,../NR_MSA.fasta,----------------------------------------------...


In [14]:
# Checkpoint used for the ProtBert-style embeddings.
bert_model_name = "yarongef/DistilProtBert"
bert_embedding_df_path = "./protbert_embeddings_NR.pkl"

# NOTE(review): `bert_embedding_df_path` is assigned above but not used here;
# the store path below is spelled "embdding" and ends in .csv although the
# function writes a pickle. Confirm which path is intended before renaming,
# since changing it re-targets the on-disk cache.
embedding_df = process_and_store_embeddings(
    df,
    model_name=bert_model_name,
    embedding_df_path="./embdding_df.csv",
    model_type="protbert",
)
# embedding_extant_df = process_and_store_embeddings(df_extant, bert_model_name, "./embedding_extant_df.csv", model_type='protbert')



Some weights of BertModel were not initialized from the model checkpoint at yarongef/DistilProtBert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at yarongef/DistilProtBert and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


KeyboardInterrupt: 

In [None]:
# `embedding_extant_df` is only assigned by a commented-out call, so evaluating
# it here raises NameError; preview the table that was actually computed.
# embedding_extant_df.head()
embedding_df.head()

Unnamed: 0,info,truncated_info,extracted_id,extracted_name,sequence,original_alignment,Sequence_aligned,model_name,protbert_mean_embedding,protbert_cls_embedding,protbert_max_embedding,protbert_weighted_embedding
0,N0,N0,N0,N0,TCAKLEPEDADENIDVTGNEPERTSTEYQMSPYPSASPESVYETSA...,./NR_MSA_ancestors.fa,----------------------------------------------...,yarongef/DistilProtBert,"[-0.17380047, 0.037478294, 0.049007155, -0.013...","[0.037427068, 0.075575374, -0.036185384, -0.23...","[0.12167235, 0.28130862, 0.21617526, 0.2462811...","[-0.10807977, 0.01987329, 0.028579468, -0.0042..."
1,N1,N1,N1,N1,TCAKLEPEDADENIDVTGNEPERTSTEYQMSPYPSASPESVYETSA...,./NR_MSA_ancestors.fa,----------------------------------------------...,yarongef/DistilProtBert,"[-0.17380047, 0.037478294, 0.049007155, -0.013...","[0.037427068, 0.075575374, -0.036185384, -0.23...","[0.12167235, 0.28130862, 0.21617526, 0.2462811...","[-0.10807977, 0.01987329, 0.028579468, -0.0042..."
2,N2,N2,N2,N2,TCAKLEPEDADENIDVTGNEPERTSTEYQMSPYPSASPESVYETSA...,./NR_MSA_ancestors.fa,----------------------------------------------...,yarongef/DistilProtBert,"[-0.17380047, 0.037478294, 0.049007155, -0.013...","[0.037427068, 0.075575374, -0.036185384, -0.23...","[0.12167235, 0.28130862, 0.21617526, 0.2462811...","[-0.10807977, 0.01987329, 0.028579468, -0.0042..."
3,N3,N3,N3,N3,TCAKLEPEDADENIDVTGNEPERTSTEYPMSPYPSASPESVYETSA...,./NR_MSA_ancestors.fa,----------------------------------------------...,yarongef/DistilProtBert,"[-0.17028484, 0.036823433, 0.049609266, -0.010...","[0.04245688, 0.08160813, -0.030255817, -0.2466...","[0.11976309, 0.27247024, 0.21556818, 0.2510118...","[-0.10674974, 0.019780792, 0.028877182, -0.003..."
4,N4,N4,N4,N4,TCAKLEPEDADENIDVTGNEPERTSTEYPMSPYPSASPEGVYETSA...,./NR_MSA_ancestors.fa,----------------------------------------------...,yarongef/DistilProtBert,"[-0.1697389, 0.035590477, 0.050194185, -0.0116...","[0.040663913, 0.08406882, -0.034715842, -0.248...","[0.122089, 0.27262154, 0.21623468, 0.24625377,...","[-0.10627797, 0.01953295, 0.029021893, -0.0033..."
