In [1]:
from dataclasses import dataclass, field
from typing import List, Set, Tuple, Optional
# import numpy as np
# import torch

from esm.sdk.api import ESMProtein
from protein_prompt_template import ProteinPromptTemplate
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.structure.aligner import Aligner
from esm.sdk.api import ESMProtein, GenerationConfig, ESM3InferenceClient
from esm.tokenization import EsmSequenceTokenizer
from esm.models.esm3 import ESM3
from biotite.structure import annotate_sse
from transformers import EsmTokenizer, EsmForSequenceClassification
from peft import PeftConfig, PeftModelForSequenceClassification
from huggingface_hub import snapshot_download
import pandas as pd
from tqdm import tqdm
import torch
from esm.utils.constants import esm3 as C

In [2]:
# Reference amino-acid sequence; used below as the template sequence and to
# size the coordinate padding.
ref_sequence: str = 'MVSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFGYGLQCFARYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSYQSALSKDPNEKRDHMVLLEFVTAAGITLGMDELYK'

# Loop-sizing settings (not read by any cell visible in this part of the notebook).
n_iterations = 20
n_processes = 20
n_generated_per_iter = 400  # Changed from variants_per_iteration
n_propagated_per_iter = 10

# Optional hooks, left unset.
masking_function = None
fitness_function = None

# Positions handed to ProteinPromptTemplate.from_esm_protein as
# `fixed_seq_indices` / `fixed_struc_indices` respectively.
sequence_fixed_indices = [1, 62, 65, 66, 67, 96, 222]
structure_fixed_indices = [*range(58, 72), 96, 222]


In [3]:
def pad_coords(coords, start_pad, end_pad):
    """Pad a coordinate tensor with all-NaN rows along its first dimension.

    Args:
        coords: Tensor of shape ``(n, ...)``. This notebook passes atom37
            coordinates of shape ``(n, 37, 3)``.
        start_pad: Number of NaN rows to prepend.
        end_pad: Number of NaN rows to append.

    Returns:
        Tensor of shape ``(start_pad + n + end_pad, ...)`` whose first
        ``start_pad`` and last ``end_pad`` rows are NaN and whose middle rows
        are ``coords``.
    """
    # Build the padding from `coords` itself (trailing shape and device)
    # instead of hard-coding (37, 3) on the default device. The dtype mirrors
    # what concatenating a default-dtype NaN block with `coords` would produce.
    trailing_shape = tuple(coords.shape[1:])
    pad_dtype = torch.promote_types(coords.dtype, torch.get_default_dtype())

    def nan_rows(count):
        return torch.full(
            (count, *trailing_shape),
            float('nan'),
            dtype=pad_dtype,
            device=coords.device,
        )

    return torch.cat([nan_rows(start_pad), coords, nan_rows(end_pad)], dim=0)

In [4]:
 # Load YFP template
pdb_id = "1YFP"  # Enhanced Yellow Fluorescent Protein
chain_id = "A"
ref_chain = ProteinChain.from_rcsb(pdb_id, chain_id)

# First and last residue numbers reported by the chain.
ref_pdb_start = ref_chain.residue_index[0]
ref_pdb_end = ref_chain.residue_index[-1]

coords = torch.tensor(ref_chain.atom37_positions)

# Size the trailing padding from the number of coordinate rows actually
# present, so the padded tensor always has exactly len(ref_sequence) rows.
# The previous `len(ref_sequence) - ref_pdb_end + 2` only produced the right
# length when a specific number of residues was absent from the chain.
end_padding = len(ref_sequence) - int(ref_pdb_start) - coords.shape[0]
if end_padding < 0:
    raise ValueError(
        f"Chain {pdb_id}:{chain_id} has {coords.shape[0]} residues starting at "
        f"{int(ref_pdb_start)}, which does not fit in a reference sequence of "
        f"length {len(ref_sequence)}"
    )

# NOTE(review): padding is added only at the two ends. If the chain has
# internal numbering gaps (the old `+ 2` correction suggests it does), rows
# after a gap are not placed at their residue number - confirm the alignment
# against ref_sequence.
padded_coords = pad_coords(coords, ref_pdb_start, end_padding)

template_esm_protein = ESMProtein(
    sequence=ref_sequence,
    coordinates=padded_coords,
)


In [None]:
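# Disabled: computes the secondary-structure string from the template via DSSP.
# A hard-coded string is assigned to `secondary_structure` in the next cell instead.
# protein_chain = template_esm_protein.to_protein_chain()
# secondary_structure = ''.join(protein_chain.dssp()).replace('X', C.MASK_STR_SHORT)
# template_esm_protein.secondary_structure = secondary_structure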

In [5]:
# Hard-coded secondary-structure string for the template protein.
# NOTE(review): appears to be a saved result of the disabled DSSP cell, with
# '_' marking positions that have no assignment - confirm before reusing.
secondary_structure = '___CGGGGGSSCEEEEEEEEEEETTEEEEEEEEEEEETTTTEEEEEEEESSSSCSSCGGGGTTTCCGGGCBCCTTTGGGCHHHHTTTTCEEEEEEEEETTSCEEEEEEEEEEETTEEEEEEEEEEECCCTTSTTTTTCBCSCCCCEEEEEEEETTTTEEEEEEEEEEEBTTSCEEEEEEEEEEEESSSSCCCCCCSCEEEEEEEEECCTTCCSSEEEEEEEEEEECC____________'

In [6]:
# Attach the secondary-structure string to the template protein.
template_esm_protein.secondary_structure = secondary_structure

In [7]:

# Wrap the template protein in a prompt template, passing the positions whose
# sequence / structure are listed as fixed.
template = ProteinPromptTemplate.from_esm_protein(
    template_esm_protein,
    fixed_struc_indices=structure_fixed_indices,
    fixed_seq_indices=sequence_fixed_indices,
)


In [8]:
# Build a prompt from the template by applying masking with mask_fraction=0.02.
# NOTE(review): how positions are chosen (and whether it is seeded) is decided
# inside ProteinPromptTemplate - confirm if reproducibility matters.
prompt = template.apply_masking(mask_fraction=0.02)

In [9]:
# Generate on the sequence track, using one step per masked index in the prompt.
sequence_config = GenerationConfig(
    num_steps=len(prompt.get_masked_indices()),
    track="sequence",
)

In [10]:
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the "esm3_sm_open_v1" ESM3 weights and move the model to that device.
esm_model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

# Sequence tokenizer instance.
esm_tokenizer = EsmSequenceTokenizer()

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

In [11]:
# Convert the masked prompt into an ESMProtein so it can be inspected below.
p = prompt.to_esm_protein()

In [12]:
# Display the secondary-structure string carried by the converted prompt.
p.secondary_structure

'___CGGGGGSSCEEEEEEEEEEETTEEEEEEEEEEEETTTTEEEEEEEESSSSCSSCGGGGTTTCCGGGCBCCTTTGGGCHHHHTTTTCEEEEEEEEETTSCEEEEEEEEEEETTEEEEEEEEEEECCCTTSTTTTTCBCSCCCCEEEEEEEETTTTEEEEEEEEEEEBTTSCEEEEEEEEEEEESSSSCCCCCCSCEEEEEEEEECCTTCCSSEEEEEEEEEEECC____________'

In [None]:
# Encode the prompt protein with the model. `encode` needs an input protein;
# calling it with no argument raises a TypeError.
esm_model.encode(prompt.to_esm_protein())

In [None]:
# Run generation on the prompt protein with the sequence-track configuration.
output = esm_model.generate(prompt.to_esm_protein(), sequence_config)