In [None]:
import torch
import numpy as np
from typing import List, Tuple
import py3Dmol

from esm.pretrained import ESM3_sm_open_v0
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.protein_structure import compute_affine_and_rmsd
from esm.utils.structure.aligner import Aligner
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.tokenization import EsmSequenceTokenizer
from esm.utils.constants.esm3 import SEQUENCE_MASK_TOKEN
from esm.models.esm3 import ESM3
from esm.utils.generation import iterative_sampling_raw
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, SamplingConfig, SamplingTrackConfig

import pandas as pd
import os
from huggingface_hub import snapshot_download
from pathlib import Path
from tqdm import tqdm
import os
from pathlib import Path
import sys
from transformers import EsmTokenizer, EsmForSequenceClassification
import torch
from peft import PeftModelForSequenceClassification
import seaborn as sns
import pandas as pd
%set_env TOKENIZERS_PARALLELISM=false


In [None]:


def pad_coords(coords, start_pad, end_pad):
    """Pad ``coords`` along the residue axis (dim 0) with all-NaN residues.

    Args:
        coords: Tensor whose first dimension indexes residues, e.g. atom37
            positions of shape ``(L, 37, 3)``. Trailing dimensions are preserved
            as-is, so other per-residue layouts work too.
        start_pad: Number of NaN residues to prepend (ints and NumPy ints accepted).
        end_pad: Number of NaN residues to append.

    Returns:
        Tensor of shape ``(start_pad + L + end_pad, *coords.shape[1:])`` on the same
        device and with the same dtype as ``coords``.
    """
    trailing_shape = tuple(coords.shape[1:])
    # new_full keeps dtype/device of the input, so CUDA or float64 inputs need no casts.
    head = coords.new_full((int(start_pad), *trailing_shape), float('nan'))
    tail = coords.new_full((int(end_pad), *trailing_shape), float('nan'))
    return torch.cat([head, coords, tail], dim=0)
    
def mask_protein_and_sequence(
    sequence: str,
    coords: torch.Tensor,
    sequence_fixed_residue_indices: List[int],
    structure_fixed_residue_indices: List[int],
    mask_fraction: float,
    rng: "np.random.Generator | None" = None,
    ) -> Tuple[str, torch.Tensor]:
    """Randomly mask residues in both the sequence and the coordinates.

    ``int(len(sequence) * mask_fraction)`` positions are drawn without replacement
    from *all* residues. Each drawn position becomes ``'_'`` in the sequence and
    all-NaN in the coordinates. Positions in the fixed-index lists are then
    restored in their own track. A drawn position that is fixed in one track is
    therefore masked only in the other track, and fewer positions than drawn may
    end up masked.

    Args:
        sequence: Reference amino-acid sequence.
        coords: Per-residue coordinates indexed along dim 0 (presumably atom37,
            ``(L, 37, 3)``); must have at least ``len(sequence)`` rows. The input
            tensor is not modified.
        sequence_fixed_residue_indices: Residues whose identity is always kept.
        structure_fixed_residue_indices: Residues whose coordinates are always kept.
        mask_fraction: Fraction of ``len(sequence)`` to draw for masking.
        rng: Optional ``np.random.Generator`` for reproducible masks. When ``None``,
            the global ``np.random`` state is used, so ``np.random.seed`` applies.

    Returns:
        ``(masked_sequence, masked_coords)``.
    """
    num_residues = len(sequence)
    num_to_mask = int(num_residues * mask_fraction)
    sampler = np.random if rng is None else rng
    mask_indices = sampler.choice(list(range(num_residues)), num_to_mask, replace=False)

    masked_sequence = list(sequence)
    masked_coords = coords.clone()  # never modify the caller's tensor

    for idx in mask_indices:
        masked_sequence[idx] = '_'
        # Index only the residue axis so any trailing coordinate shape is supported.
        masked_coords[idx] = float('nan')

    # Restore fixed positions after masking (see docstring for the consequence).
    for idx in sequence_fixed_residue_indices:
        masked_sequence[idx] = sequence[idx]
    for idx in structure_fixed_residue_indices:
        masked_coords[idx] = coords[idx]

    return ''.join(masked_sequence), masked_coords

def sequence_diff(ref, query):
    """Describe substitutions of ``query`` relative to ``ref``, e.g. ``"C1G/E3F"``.

    Each token is ``<ref residue><0-based index><query residue>``, and tokens are
    joined with ``/``. Returns ``''`` when there are no substitutions, when
    ``query`` is missing (``None`` or a float such as a pandas NaN), or when
    ``query`` is shorter than ``ref``: positions cannot be compared without an
    alignment. Characters of ``query`` beyond ``len(ref)`` are ignored.
    """
    # Explicit guards instead of a blanket ``except`` so genuine errors surface.
    if query is None or isinstance(query, float):
        return ''
    if len(query) < len(ref):
        return ''
    return '/'.join(
        f"{ref_aa}{i}{query_aa}"
        for i, (ref_aa, query_aa) in enumerate(zip(ref, query))
        if ref_aa != query_aa
    )


# Seed the global RNGs so the random prompt mask (np.random) and torch-based
# sampling (which ESM3 generation presumably draws from) are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the ESM3 model on the detected device; hard-coding "cuda" would fail on
# CPU-only machines even though `device` falls back to CPU.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)
# ESM3 sequence tokenizer. Named distinctly because `tokenizer` is reassigned to
# the SaProt tokenizer further down.
esm_tokenizer = EsmSequenceTokenizer()


# Set up SaProt regressors: one base checkpoint plus two SaProtHub adapters
# (named for EYFP and for general fluorescence).
yfp_adapter_input = "SaProtHub/Model-EYFP-650M"
base_model_name = "westlake-repl/SaProt_650M_AF2"
fluor_adapter_input = 'SaProtHub/Model-Fluorescence-650M'


def _load_saprot_regressor(adapter_repo_id: str) -> PeftModelForSequenceClassification:
    """Download a SaProtHub adapter and attach it to a freshly loaded SaProt base model.

    A new base model is loaded per adapter, presumably because PEFT modifies the
    base model it wraps in place; sharing one base would couple the two regressors.
    The returned model is on the module-level ``device`` and in eval mode.
    """
    adapter_path = snapshot_download(repo_id=adapter_repo_id, repo_type="model")
    base = EsmForSequenceClassification.from_pretrained(base_model_name, num_labels=1)
    regressor = PeftModelForSequenceClassification.from_pretrained(base, adapter_path)
    return regressor.to(device).eval()


saprot_yfp_model = _load_saprot_regressor(yfp_adapter_input)
saprot_fluor_model = _load_saprot_regressor(fluor_adapter_input)

# SaProt tokenizer used for scoring; `device` is reused from the ESM setup above.
tokenizer = EsmTokenizer.from_pretrained(base_model_name)



# Set up prompts.

# Reference sequence: eYFP. 0-based indices into this string coincide with the
# standard GFP residue numbering (eYFP carries an extra Val at index 1); e.g.
# indices 65-67 are the G-Y-G chromophore.
ref_sequence = 'MVSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFGYGLQCFARYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGITLGMDELYK'

# Load the YFP template structure (chain A) from RCSB.
pdb_id = "1YFP"  # Enhanced Yellow Fluorescent Protein
# pdb_id = "1QY3"
chain_id = "A"
ref_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

# Pad the template coordinates with NaN residues so they line up with ref_sequence.
# NOTE(review): assumes the PDB residue numbers equal 0-based indices into
# ref_sequence; TODO confirm.
ref_pdb_start = ref_chain.residue_index[0] # pad this much at the front
ref_pdb_end = ref_chain.residue_index[-1]
# NOTE(review): the "+ 2" is unexplained. It presumably compensates for gaps in the
# PDB residue numbering (e.g. a chromophore modelled as a single residue).
# Verify that len(padded_coords) == len(ref_sequence).
end_padding = len(ref_sequence) - ref_pdb_end + 2

coords = torch.tensor(ref_chain.atom37_positions)
# The next two slices are not referenced elsewhere in the visible code.
ref_sequence_beginning = ref_sequence[:ref_pdb_start]
ref_sequence_ending = ref_sequence[ref_pdb_end:]
padded_coords = pad_coords(coords, ref_pdb_start, end_padding)
padded_ref_chain = ProteinChain.from_atom37(padded_coords)


# Positions held fixed in the prompt, as 0-based indices into ref_sequence.
# Because eYFP carries an extra Val at index 1, these coincide with standard GFP
# numbering (65-67 is the G-Y-G chromophore; also 62, 96 and 222).
positions = [1, 62, 65, 66, 67, 96, 222]
sequence_fixed_indices = list(positions)
# Coordinates kept for indices 58-71 (spanning the chromophore) plus 96 and 222.
structure_fixed_indices = [*range(58, 72), 96, 222]

# Masking config. The fraction applies to ALL residues. Drawn positions that are
# fixed in a track are restored in that track afterwards, so fewer than
# num_to_mask positions may end up masked.
mask_fraction = 0.025
num_to_mask = int(len(ref_sequence) * mask_fraction)

sequence_prompt, structure_prompt = mask_protein_and_sequence(
    ref_sequence, padded_coords, sequence_fixed_indices, structure_fixed_indices, mask_fraction,
)
prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)



# Evaluation Metrics

def calculate_sequence_identity(seq1: str, seq2: str) -> float:
    """Fraction of aligned positions at which ``seq1`` and ``seq2`` agree.

    The sequences must already be the same length (positional comparison, no
    alignment); otherwise an AssertionError is raised.
    """
    assert len(seq1) == len(seq2), "Sequences must be of equal length"
    matches = 0
    for left, right in zip(seq1, seq2):
        if left == right:
            matches += 1
    return matches / len(seq1)

def calculate_rmsd(coords1: np.ndarray, coords2: np.ndarray) -> float:
    """Root-mean-square deviation between two already-superposed point sets.

    Inputs are array-likes of shape ``(..., 3)``. The squared Euclidean distance is
    taken over the last (xyz) axis and averaged over all remaining points, so both
    ``(N, 3)`` and atom37-style ``(N, 37, 3)`` coordinates are handled correctly.
    No superposition is performed and NaNs propagate to the result.
    """
    diff = np.asarray(coords1) - np.asarray(coords2)
    # axis=-1 is the xyz axis for any rank; axis=1 would sum over atoms for (N, 37, 3).
    return np.sqrt(np.mean(np.sum(diff**2, axis=-1)))

def calculate_template_rmsd(variant: ProteinChain, template: ProteinChain, residues: List[int]) -> float:
    """RMSD between ``variant`` and ``template`` over the given residues.

    Both chains are restricted to ``residues`` (0-based residue positions, applied
    identically to both chains). They are then superposed with ``Aligner`` on that
    subset only, so the value reflects the local geometry of e.g. the chromophore
    rather than the whole protein. If ``residues`` is ``None`` or empty, the whole
    chains are aligned.

    Returns:
        ``Aligner(...).rmsd`` for the (selected) chains.
    """
    if residues is not None and len(residues) > 0:
        # NOTE(review): relies on ProteinChain.__getitem__ accepting an index array
        # (present in recent `esm` releases; verify for the pinned version).
        selection = np.asarray(residues, dtype=int)
        variant = variant[selection]
        template = template[selection]
    return Aligner(variant, template).rmsd

def calculate_pseudo_perplexity(model: ESM3_sm_open_v0, sequence: str) -> float:
    """Perplexity of ``sequence`` under ESM3 from a single unmasked forward pass.

    NOTE: despite the name, this is not the masked-marginal pseudo-perplexity. That
    would need one forward pass per masked position. Here every residue is visible
    to the model, so scores are optimistic and mainly useful for relative ranking.

    Returns:
        ``exp(-mean log p(residue_i))`` over the residue tokens, as a Python float.
    """
    tokenizer = EsmSequenceTokenizer()
    tokens = torch.tensor([tokenizer.encode(sequence)]).to(device)
    with torch.no_grad():
        output = model(sequence_tokens=tokens)
    # Drop the special tokens at both ends. ESM3 is a masked (non-autoregressive)
    # model, so logits at position i score the token at position i. Logits and
    # targets must therefore be taken from the SAME positions.
    residue_tokens = tokens[0, 1:-1]
    # Upcast in case the model runs in reduced precision.
    residue_logits = output.sequence_logits[0, 1:-1].float()
    log_probs = torch.log_softmax(residue_logits, dim=-1)
    token_log_probs = log_probs.gather(-1, residue_tokens.unsqueeze(-1)).squeeze(-1)
    return torch.exp(-token_log_probs.mean()).item()

def calculate_n_gram_score(sequence: str, n: int = 3) -> float:
    """Shannon entropy (nats) of the sequence's overlapping n-gram distribution.

    Higher values mean more diverse n-grams. Low values flag repetitive,
    low-complexity sequences. This is a simplified, self-contained proxy: it does
    not compare against a background n-gram distribution.

    Returns 0.0 when the sequence has fewer than ``n`` characters (no n-grams).
    """
    from collections import Counter
    ngrams = [sequence[i:i + n] for i in range(len(sequence) - n + 1)]
    if not ngrams:
        return 0.0  # avoid dividing by zero for sequences shorter than n
    total = len(ngrams)
    counts = Counter(ngrams)
    return -sum(count * np.log(count / total) for count in counts.values()) / total

def calculate_pssm_score(sequence: str, pssm: np.ndarray) -> float:
    """Placeholder PSSM score that always returns 0.0.

    TODO: implement (or import) a real position-specific scoring matrix for YFP.
    Both arguments are currently ignored.
    """
    placeholder_score = 0.0
    return placeholder_score

def calculate_n_terminus_coil_count(chain: ProteinChain, n: int = 12) -> int:
    """Count how many of the first ``n`` residues have secondary-structure code S, T or C.

    NOTE(review): assumes ``chain.dssp()`` returns a per-residue sequence of
    single-character DSSP-style codes; TODO confirm against the ProteinChain API.
    """
    coil_codes = ('S', 'T', 'C')
    leading_codes = chain.dssp()[:n]
    return sum(code in coil_codes for code in leading_codes)
# Number of ESM3 variants to sample from the masked prompt.
num_variants = 5000

# ESM generation loop: fill in the masked sequence positions, then fold the
# completed sequence. num_to_mask (positions drawn for masking) is the step budget.
variants = []
for _ in tqdm(range(num_variants), desc="Generating YFP variants"):
    # Fresh configs per call, in case model.generate mutates them (unverified).
    sequence_generation_config = GenerationConfig(
        track="sequence",
        num_steps=num_to_mask,
        # temperature=0.1
    )
    variant_sequence = model.generate(prompt, sequence_generation_config)

    # Generate structure for the YFP variant
    structure_generation_config = GenerationConfig(
        track="structure",
        num_steps=num_to_mask,
        # temperature=0.1
    )
    variant = model.generate(variant_sequence, structure_generation_config)

    # Convert ESMProtein to ProteinChain for downstream metrics.
    variant_chain = variant.to_protein_chain()
    variants.append(variant_chain)



# Evaluate every generated variant against the eYFP reference.
# Residue indices are 0-based positions in ref_sequence. These match standard GFP
# numbering for eYFP (index 1 is the inserted Val): 65-67 is the G-Y-G chromophore,
# and 58-71 is the span whose coordinates are fixed in the structure prompt.
chromophore_indices = [65, 66, 67]
template_helix_indices = list(range(58, 72))

esm_outputs = []
for ii, variant in enumerate(variants):
    seq_identity = calculate_sequence_identity(variant.sequence, ref_sequence)
    chromophore_rmsd = calculate_template_rmsd(variant, padded_ref_chain, chromophore_indices)
    template_helix_rmsd = calculate_template_rmsd(variant, padded_ref_chain, template_helix_indices)
    pseudo_perplexity = calculate_pseudo_perplexity(model, variant.sequence)
    n_gram_score = calculate_n_gram_score(variant.sequence)

    esm_outputs.append({
        "name": f'variant_{ii}',
        "sequence": variant.sequence,
        "seq_identity": seq_identity,
        "chromophore_rmsd": chromophore_rmsd,
        "template_helix_rmsd": template_helix_rmsd,
        "pseudo_perplexity": pseudo_perplexity,
        "n_gram_score": n_gram_score,
    })

# Build the metrics table once from the list of records.
esm_outputs_df = pd.DataFrame(esm_outputs)



def AA_to_SA(aa_seq):
    """Convert an amino-acid string into a SaProt structure-aware token string.

    Every residue is followed by ``'#'``, so ``"MK"`` becomes ``"M#K#"``.
    NOTE(review): ``'#'`` is presumably SaProt's placeholder for an unknown
    structure (3Di) token; confirm against the SaProt vocabulary.
    """
    return ''.join(residue + '#' for residue in aa_seq)



# Reference fluorescent proteins scored alongside the ESM variants. Sequences
# pasted in 10-residue groups have their spaces stripped with .replace(' ', '').
aa_seqs = [
    {'name': 'eYFP', 'sequence': 'MVSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFGYGLQCFARYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGITLGMDELYK',},
    {'name': 'Citrine', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKFICT TGKLPVPWPT LVTTFGYGLM CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYNSHN VYIMADKQKN GIKVNFKIRH NIEDGSVQLA DHYQQNTPIG DGPVLLPDNH YLSYQSALSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    {'name': 'mCitrine', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKFICT TGKLPVPWPT LVTTFGYGLM CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYNSHN VYIMADKQKN GIKVNFKIRH NIEDGSVQLA DHYQQNTPIG DGPVLLPDNH YLSYQSKLSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    {'name': 'Citrine2', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV TGEGEGDATY GKLTLKFICT TGKLPVPWPT LVTTFGYGLT CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNHNSHY VYIMADKQKN GIKANFKIRH NIEDGSVQLA DHYQQNTPIG DGPVLLPDNH YLSYQSQLSK DPNEERDHTV LLEFVTAAGI TLGMGELYK'.replace(' ', '')},
    {'name': 'Venus', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKLICT TGKLPVPWPT LVTTLGYGLQ CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYNSHN VYITADKQKN GIKANFKIRH NIEDGGVQLA DHYQQNTPIG DGPVLLPDNH YLSYQSALSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    {'name': 'mVenus', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKLICT TGKLPVPWPT LVTTLGYGLQ CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYNSHN VYITADKQKN GIKANFKIRH NIEDGGVQLA DHYQQNTPIG DGPVLLPDNH YLSYQSKLSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    {'name': 'mTurquoise', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKFICT TGKLPVPWPT LVTTLSWGVQ CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYISDN VYITADKQKN GIKANFKIRH NIEDGGVQLA DHYQQNTPIG DGPVLLPDNH YLSTQSKLSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    {'name': 'mEmerald', 'sequence': 'MVSKGEELFT GVVPILVELD GDVNGHKFSV SGEGEGDATY GKLTLKFICT TGKLPVPWPT LVTTLTYGVQ CFARYPDHMK QHDFFKSAMP EGYVQERTIF FKDDGNYKTR AEVKFEGDTL VNRIELKGID FKEDGNILGH KLEYNYNSHK VYITADKQKN GIKVNFKTRH NIEDGSVQLA DHYQQNTPIG DGPVLLPDNH YLSTQSKLSK DPNEKRDHMV LLEFVTAAGI TLGMDELYK'.replace(' ', '')},
    # The last two differ substantially from eYFP across the whole sequence
    # (presumably out-of-family controls), so position-wise diffs against eYFP
    # are not meaningful for them.
    {'name': 'mRuby3', 'sequence': 'MVSKGEELIK ENMRMKVVME GSVNGHQFKC TGEGEGRPYE GVQTMRIKVI EGGPLPFAFD ILATSFMYGS RTFIKYPADI PDFFKQSFPE GFTWERVTRY EDGGVVTVTQ DTSLEDGELV YNVKVRGVNF PSNGPVMQKK TKGWEPNTEM MYPADGGLRG YTDIALKVDG GGHLHCNFVT TYRSKKTVGN IKMPGVHAVD HRLERIEESD NETYVVQREV AVAKYSNLGG GMDELYK'.replace(' ', '')},
    {'name': 'mStayGold2', 'sequence': 'MVSTGEELFT GVVPFKFQLK GTINGKSFTV EGEGEGNSHE GSHKGKYVCT SGKLPMSWAA LGTSFGYGMK YYTKYPSGLK NWFHEVMPEG FTYDRHIQYK GDGSIHAKHQ HFMKNGTYHN IVEFTGQDFK ENSPVLTGDM DVSLPNEVQH IPRDDGVECT VTLTYPLLSD ESKCVEAYQN TIIKPLHNQP APDVPYHWIR KQYTQSKDDT EERDHIIQSE TLEAHLYSRT KLE'.replace(' ', '')},
    ]


# Score the reference panel plus every generated variant with both SaProt regressors.
aa_seqs += esm_outputs_df[['name', 'sequence']].to_dict(orient='records')

# Keep the raw sequence alongside the SaProt input so that every row, including the
# reference proteins, has a sequence for the mutation diff below.
sa_seqs = [
    {
        'name': aa_seq['name'],
        'sequence': aa_seq['sequence'],
        'sa_sequence': AA_to_SA(aa_seq['sequence']),
        'length': len(aa_seq['sequence']),
    }
    for aa_seq in aa_seqs
]
df = pd.DataFrame(sa_seqs)

yfp_outputs_list = []
fluor_outputs_list = []
with torch.no_grad():
    for seq in tqdm(df['sa_sequence']):
        inputs = tokenizer(seq, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # .item() also works for reduced-precision logits, unlike .numpy().
        yfp_outputs_list.append(saprot_yfp_model(**inputs).logits[0, 0].item())
        fluor_outputs_list.append(saprot_fluor_model(**inputs).logits[0, 0].item())

df['yfp_model_score'] = yfp_outputs_list
df['fluor_model_score'] = fluor_outputs_list
# Merge the ESM metrics without their 'sequence' column. Merging it would leave the
# reference proteins with NaN sequences and hence an empty diff / 0 mutations.
# ESM metric columns stay NaN for the reference proteins.
df = df.merge(esm_outputs_df.drop(columns=['sequence']), how='left', on='name')
df['diff'] = df['sequence'].apply(lambda seq: sequence_diff(ref_sequence, seq))
df['n_mutations'] = df['diff'].apply(lambda x: len(x.split('/')) if x != '' else 0)
df = df.sort_values('n_mutations', ascending=False)

In [None]:
# Only rows with ESM metrics (the generated variants) are plotted; the reference
# proteins have NaN metrics after the left merge.
sns.pairplot(
    data=df.dropna(subset=['pseudo_perplexity']),
    vars=[
        'yfp_model_score',
        'fluor_model_score',
        'chromophore_rmsd',
        'template_helix_rmsd',
        'pseudo_perplexity',
        'n_gram_score',
    ],
)

# viz


In [None]:

# Visualize original YFP structure and generated variant
def visualize_structures(template: ProteinChain, variant: ProteinChain, highlight_residues: List[int] = None):
    """Show ``template`` and ``variant`` side by side as cartoons in a 1x2 py3Dmol grid.

    The template is drawn light grey (left) and the variant light blue (right).
    When ``highlight_residues`` is non-empty, those residues are recoloured red in
    both panels. Residues are matched by py3Dmol's ``resi`` selector, i.e. against
    the residue numbers in each chain's PDB text. NOTE(review): callers pass 0-based
    sequence indices; confirm these agree with both chains' numbering.

    Returns:
        The ``py3Dmol.view``; call ``.show()`` to render it.
    """
    view = py3Dmol.view(width=800, height=400, viewergrid=(1, 2))
    panels = [
        ((0, 0), template, "lightgrey"),
        ((0, 1), variant, "lightblue"),
    ]
    for grid_cell, chain, base_colour in panels:
        view.addModel(chain.to_pdb_string(), "pdb", viewer=grid_cell)
        view.setStyle({"cartoon": {"color": base_colour}}, viewer=grid_cell)
        if highlight_residues:
            view.addStyle({"resi": highlight_residues}, {"cartoon": {"color": "red"}}, viewer=grid_cell)
    view.zoomTo()
    return view

print("Visualizing original YFP structure (left) and generated variant (right) with key residues highlighted in red:")
# Union of sequence- and structure-fixed positions. `variant_chain` is the last
# variant produced by the generation loop.
all_key_residues = list({*sequence_fixed_indices, *structure_fixed_indices})
key_residue_view = visualize_structures(ref_chain, variant_chain, all_key_residues)
key_residue_view.show()


# saprot