In [None]:
import torch

from esm.models.esm3 import ESM3
from esm.sdk.api import (
    ESMProtein,
    SamplingConfig,
    SamplingTrackConfig,
)
from esm.utils.constants.models import ESM3_OPEN_SMALL
from esm.utils.structure.protein_chain import ProteinChain

# Initialize the client.
# Prefer the GPU, but fall back to the CPU when CUDA is not available so this
# cell still runs (more slowly) instead of failing while loading the model.
client = ESM3.from_pretrained(
    ESM3_OPEN_SMALL,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Load the protein: fetch PDB entry 1utn from RCSB and wrap it as an ESMProtein.
# The two calls are chained so `protein` is only ever bound to the ESMProtein.
protein = ESMProtein.from_protein_chain(ProteinChain.from_rcsb("1utn"))
# Clear the coordinates, then replace the sequence with the one given below.
# NOTE(review): this overrides the sequence that came from 1utn -- confirm the
# RCSB fetch above is still wanted and that saved outputs match this sequence.
protein.coordinates = None
protein.sequence = 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'

# Predict function: encode the protein, run one forward-and-sample pass over
# every track, then decode the sampled tensor back into a protein object.
protein_tensor = client.encode(protein)
inference_output = client.forward_and_sample(
    protein_tensor,
    SamplingConfig(
        sequence=SamplingTrackConfig(),
        structure=SamplingTrackConfig(),
        secondary_structure=SamplingTrackConfig(),
        sasa=SamplingTrackConfig(),
        # Only the function track sets only_sample_masked_tokens=False; the
        # other tracks use the SamplingTrackConfig defaults.
        function=SamplingTrackConfig(only_sample_masked_tokens=False),
    ),
)
protein_tensor_with_function = inference_output.protein_tensor
protein_with_function = client.decode(protein_tensor_with_function)
print(protein_with_function)


In [None]:
import re

def split_annotations(annotations, residue_vocab) -> tuple[list, list, list]:
    """Partition annotations into InterPro, keyword and residue groups.

    Each annotation is classified by its ``label`` attribute, in this order:

    1. label contains an InterPro accession (``IPR`` followed by digits)
       -> InterPro group;
    2. label is a member of ``residue_vocab`` -> residue group;
    3. anything else -> keyword group.

    Args:
        annotations: Iterable of objects exposing a ``label`` string attribute.
        residue_vocab: Iterable of hashable residue-annotation labels. It is
            consumed exactly once, so lists, sets, dict keys and one-shot
            iterators all behave the same.

    Returns:
        ``(interpro, keywords, residue)`` -- three lists that preserve the
        input order and hold the original annotation objects.
    """
    ipr_regex = re.compile(r"IPR\d+")
    # Build the lookup once: membership on a list is a linear scan per
    # annotation, and on a one-shot iterator it would silently exhaust the
    # vocabulary after the first test.
    residue_labels = frozenset(residue_vocab)

    interpro, keywords, residue = [], [], []
    for a in annotations:
        if ipr_regex.search(a.label):
            interpro.append(a)
        elif a.label in residue_labels:
            residue.append(a)
        else:
            keywords.append(a)
    return interpro, keywords, residue

# Labels held by the residue-annotation tokenizer (read from its private
# `_labels` attribute); presumably the full residue-annotation vocabulary.
residue_vocab = client.tokenizers.residue_annotations._labels
# Bucket the predicted function annotations into the three groups returned by
# split_annotations: InterPro entries, keywords, and residue annotations.
interpro_list, keyword_list, residue_list = split_annotations(
    annotations=protein_with_function.function_annotations,
    residue_vocab=residue_vocab,
)




In [None]:
# Display the annotations whose label contains an InterPro accession (IPR + digits).
interpro_list

[FunctionAnnotation(label='Acriflavin resistance protein (IPR001036)', start=1, end=25),
 FunctionAnnotation(label='Acriflavin resistance protein (IPR001036)', start=42, end=79),
 FunctionAnnotation(label='Acriflavin resistance protein (IPR001036)', start=95, end=169),
 FunctionAnnotation(label='Acriflavin resistance protein (IPR001036)', start=183, end=220),
 FunctionAnnotation(label='Peptidase S1C (IPR001940)', start=1, end=25),
 FunctionAnnotation(label='Peptidase S1C (IPR001940)', start=42, end=79),
 FunctionAnnotation(label='Peptidase S1C (IPR001940)', start=95, end=169),
 FunctionAnnotation(label='Peptidase S1C (IPR001940)', start=183, end=220),
 FunctionAnnotation(label='Tail specific protease (IPR005151)', start=1, end=25),
 FunctionAnnotation(label='Tail specific protease (IPR005151)', start=42, end=79),
 FunctionAnnotation(label='Tail specific protease (IPR005151)', start=94, end=169),
 FunctionAnnotation(label='Tail specific protease (IPR005151)', start=183, end=220),
 Funct

In [None]:
# Display the annotations that matched neither an InterPro accession nor the residue vocabulary.
keyword_list

[FunctionAnnotation(label='a protein', start=1, end=220),
 FunctionAnnotation(label='chymotrypsin', start=1, end=220),
 FunctionAnnotation(label='endopeptidase', start=1, end=220),
 FunctionAnnotation(label='endopeptidase activity', start=1, end=220),
 FunctionAnnotation(label='fold', start=1, end=36),
 FunctionAnnotation(label='fold', start=41, end=220),
 FunctionAnnotation(label='like fold', start=1, end=25),
 FunctionAnnotation(label='like fold', start=42, end=79),
 FunctionAnnotation(label='like fold', start=94, end=169),
 FunctionAnnotation(label='like fold', start=183, end=220),
 FunctionAnnotation(label='pa', start=1, end=25),
 FunctionAnnotation(label='pa', start=42, end=79),
 FunctionAnnotation(label='pa', start=94, end=220),
 FunctionAnnotation(label='peptidase', start=1, end=220),
 FunctionAnnotation(label='peptidase activity', start=1, end=220),
 FunctionAnnotation(label='peptidase s1', start=1, end=220),
 FunctionAnnotation(label='proteolysis', start=1, end=170),
 Function

In [None]:
# Display the annotations whose label was found in residue_vocab.
residue_list

[]

In [7]:
from biotite.database import rcsb

In [8]:
# Display the imported biotite.database.rcsb module object.
# NOTE(review): looks like leftover exploration -- `rcsb` is not used in the
# visible cells; confirm and remove if it is not needed.
rcsb

<module 'biotite.database.rcsb' from '/home/davidar/anaconda3/envs/esm3_h5py/lib/python3.10/site-packages/biotite/database/rcsb/__init__.py'>