In [11]:
import torch

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the client
client = ESM3.from_pretrained(ESM3_OPEN_SMALL, device=torch.device("cuda:2"))

# Load the protein
protein = ESMProtein(sequence = 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD')

print(protein)

# Predict function
protein_tensor = client.encode(protein)
inference_output = client.forward_and_sample(
    protein_tensor,
    SamplingConfig(
        sequence=SamplingTrackConfig(),
        structure=SamplingTrackConfig(),
        secondary_structure=SamplingTrackConfig(),
        sasa=SamplingTrackConfig(),
        function=SamplingTrackConfig(only_sample_masked_tokens=False),
        
    ),
)
protein_tensor_with_function = inference_output.protein_tensor
protein_with_function = client.decode(protein_tensor_with_function)
print('\n--------------\n')
print(protein_with_function)


  state_dict = torch.load(


ESMProtein(sequence='MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD', secondary_structure=None, sasa=None, function_annotations=None, coordinates=None, plddt=None, ptm=None, potential_sequence_of_concern=False)


  state_dict = torch.load(
  state_dict = torch.load(



--------------

ESMProtein(sequence='MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD', secondary_structure='CCCCCCCCCCSCCCCCHHHGCGTCHCHGCCTESGCCCHHSCTTTCCHHHHHHHTHSCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCTCCCETCCCCECCETCCCECEEEEECSSSCCEEEEEEETTTTEEEESECEEECEEEEESCCCCTCCEEEEEEEECCGCCSHSCCEEETSHGHTCTCHSCECESCCEEEESCGEEEECETTSSCCTEECCCESCTCTTCCEEEEEEECHSCTSSTGGEGGCCEEEEEEEECTTSCEEEEECEEEEEESCTGCEHTBHHHHCHCCSCCCCCCCCTCCCCCCCCSCCCCCCCCCCTCTCCSHSCECCTSHHHHHIHHHHHHHCCHSHHHTCCTCCCCCCCCCECCCCCCCCCCCCCCTECCCCCCCC', sasa=[inf, inf, inf, 126.23686218261719, inf, 105.51325988769531, 131.81536865234375, 122.89801788330078, 97.90599822998047, 124.61527252197266, inf,

In [None]:
import re

def split_annotations(annotations, residue_vocab):
    interpro, keywords, residue = [], [], []
    ipr_regex = re.compile(r"IPR\d+")
    for a in annotations:
        if ipr_regex.search(a.label):
            interpro.append(a)
        elif a.label in residue_vocab:
            residue.append(a)
        else:
            keywords.append(a)
    return interpro, keywords, residue

residue_vocab = client.tokenizers.residue_annotations._labels
interpro_list, keyword_list, residue_list = split_annotations(
    protein_with_function.function_annotations, residue_vocab
)


In [13]:
interpro_list

[FunctionAnnotation(label='6-phosphogluconate dehydrogenase, domain 2 (IPR013328)', start=144, end=148),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal domain superfamily (IPR036409)', start=122, end=127),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal domain superfamily (IPR036409)', start=141, end=164),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal domain superfamily (IPR036409)', start=169, end=257),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal domain superfamily (IPR036409)', start=263, end=273),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal (IPR001303)', start=141, end=147),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal (IPR001303)', start=169, end=201),
 FunctionAnnotation(label='Class II aldolase/adducin N-terminal (IPR001303)', start=206, end=247),
 FunctionAnnotation(label='Small GTPase (IPR001806)', start=264, end=271),
 FunctionAnnotation(label='T-box transcription fact

In [14]:
keyword_list

[FunctionAnnotation(label='binding transcription', start=115, end=277),
 FunctionAnnotation(label='dna templated', start=115, end=277),
 FunctionAnnotation(label='factor activity', start=115, end=277),
 FunctionAnnotation(label='of biosynthetic', start=115, end=277),
 FunctionAnnotation(label='of dna', start=115, end=277),
 FunctionAnnotation(label='of gene', start=115, end=277),
 FunctionAnnotation(label='of macromolecule', start=115, end=277),
 FunctionAnnotation(label='of metabolic', start=115, end=277),
 FunctionAnnotation(label='of nitrogen', start=115, end=277),
 FunctionAnnotation(label='of nucleobase', start=115, end=277),
 FunctionAnnotation(label='of primary', start=115, end=277),
 FunctionAnnotation(label='of rna', start=115, end=277),
 FunctionAnnotation(label='regulator activity', start=115, end=277),
 FunctionAnnotation(label='rna biosynthetic', start=115, end=277),
 FunctionAnnotation(label='templated', start=115, end=277),
 FunctionAnnotation(label='templated transcript

In [15]:
residue_list

[FunctionAnnotation(label='trna', start=261, end=265)]

In [7]:
len(residue_vocab)

1474

In [8]:
residue_vocab

['active site',
 'dimer interface',
 'atp binding site',
 'metal: magnesium',
 'act_site: proton acceptor',
 'binding: substrate',
 'binding: nad',
 'substrate binding site',
 'metal binding site',
 'homodimer interface',
 'np_bind: atp',
 'metal: calcium',
 'dna binding site',
 'np_bind: nad',
 'act_site: proton donor',
 'mod_res: n6 (pyridoxal phosphate)lysine',
 'act_site: nucleophile',
 'binding: atp',
 'ligand binding site',
 'np_bind: nadp',
 'abc transporter signature motif',
 'walker b',
 'd loop',
 'np_bind: gtp',
 'q loop/lid',
 'catalytic residue',
 'catalytic site',
 'h loop/switch region',
 'walker a/p loop',
 'phosphorylation site',
 'catalytic residues',
 'binding: s adenosyl l methionine',
 's adenosylmethionine binding site',
 'binding: nadp',
 'abc atpase subunit interface',
 'np_bind: fad',
 'binding: fad',
 'crosslnk: glycyl lysine isopeptide (lys gly) (interchain with g cter in ubiquitin)',
 'binds [4fe 4s] adomet cluster',
 'sequence specific dna binding site',
 '