In [1]:
import sys
import os
# Relative paths below resolve against the kernel's current working directory (normally the notebook's folder).
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # set before the esm imports below so it is already in place when they load
sys.path.append('../../esm') ## local esm checkout; ignore if installing esm3 as a package
sys.path.append('..')  # parent directory; presumably where the DomainPrediction package lives — confirm

import numpy as np
import matplotlib.pyplot as plt
import torch

from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

2024-10-18 12:38:34.856887: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-18 12:38:35.769272: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-18 12:38:35.774387: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-18 12:38:37.087965: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from DomainPrediction import BaseProtein
from DomainPrediction.eval import metrics
from DomainPrediction.utils import helper
from DomainPrediction.utils.constants import *

In [18]:
# Load the full-length homology model and the isolated T-domain structure.
protein = ProteinChain.from_pdb('../../Data/gxps/gxps_ATC_hm_6mfy.pdb')
T_domain = ProteinChain.from_pdb('../../Data/gxps/gxps_T.pdb')

# Sequence prompt: keep the residues at the A- and C-domain indices and mask every other
# position with '_'. The index union is built once as a set instead of re-concatenating
# and linearly scanning both lists for every residue.
fixed_positions = set(A_gxps_atc) | set(C_gxps_atc)
sequence_prompt = ''.join(
    protein[i].sequence if i in fixed_positions else '_' for i in range(len(protein))
)

# Structure prompt: 37 atom slots x 3 coordinates per residue, all NaN (no structure)
# until specific residues are filled in.
structure_prompt = torch.full((len(sequence_prompt), 37, 3), np.nan)

In [6]:
# Inspect the masked prompt: residues at A_gxps_atc/C_gxps_atc are kept, every other position is '_'.
sequence_prompt

'VCVHQLFEQQIEKTPDAIAVIYENQTLSYAELNARANRLAHQLIALGVAPDQRVAICVTRSLARIIGLLAVLKAGGAYVPLDPAYPGERLAYMLTDATPVILMADNVGRAALSEDILATLTVLDPNTLLEQPDHNPQVSGLTPQHLAYVIYTSGSTGRPKGVMIEHRSVVNLTLTQITQFDVCATSRMLQFASFGFDASVWEIMMALSCGAMLVIPTETVRQDPQRLWRYLEEQAITHACLTPAMFHDGTDLPAIAIKPTLIFAGEAPSPALFQALCSRADLFNAYGPTEITVCATTWDCPADYTGGVIPIGSPVANKRLYLLDEHRQPVPLGTVGELYIGGVGVARGYLNRPELTAERFLNDPFSDETNARMYRAGDLARYLPDGNLVFVGRNDQQVKIRGFRIEPGEIEARLVEHSEVSEALVLALGDGQDKRLVAYVVALADDGLATKLREHLSDILPDYMIPAAFVRLDAFPLTPNGKLDRRSLP___________________________________________________________________________________________________________________QAEIDRIVEQVPGGIANIQDIYALSPLQDGILFHHLLANEGDPYLLITQQAFADRPLLNRYLAAVQQVVDRHDILRTAFIWEGLSVPAQVICRQAPLSVTELTLNPADGAISNQLAQRFDPRRHRIDLNQAPLLRFVVAQESDGRWILLQLLHHLIGDHTTLEVMNSEVQACLLGQMDSLPAPVPFRHLVAQARQGVSQAEHTRFFTDMLAEVDEPTLLFGLAEAHHDGSQVTESHRMLTAGLNERLRGQARRLGVSVAALCHLAWAQVLSRTSGQTQVVFGTVLFGRMQAGEGSDSGMGLFINTLPLRLDIDNTPVRDSVRAAHSRLAGLLEHEHASLALAQRCSGVESGTPLFNALLNYRHNTQPVTPDEIVSGIEFLGAQERTNYPFVLSVE

In [7]:
# Residues of the full-length model at the T-domain indices, in ascending position order.
t_positions = set(T_gxps_atc)
''.join(protein[i].sequence for i in range(len(protein)) if i in t_positions)

'GEIEIALATIWRELLNVEQVGRHDSFFALGGHSLLAVRMIERLRRIGLGLSVQTLFQHPTLSVLAQSLVP'

In [24]:
# Place the isolated T-domain coordinates at the T positions of the full-length structure prompt.
# Advanced-index assignment requires the source dtype to match the destination, so cast explicitly.
assert len(T_domain) == len(T_gxps_atc), 'T-domain structure and T index list differ in length'
structure_prompt[T_gxps_atc] = torch.as_tensor(T_domain.atom37_positions, dtype=structure_prompt.dtype)

In [25]:
# A residue counts as structure-conditioned when at least one of its 37 atom slots has no NaN coordinate.
atom_has_nan = torch.isnan(structure_prompt).any(dim=-1)
residue_unresolved = atom_has_nan.all(dim=-1)
conditioned_indices = torch.where(~residue_unresolved)[0].tolist()
print("Indices with structure conditioning: ", conditioned_indices)

Indices with structure conditioning:  [505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574]


In [27]:
# Combine the masked sequence and the partial (T-domain-only) coordinates into one ESM3 prompt.
esm_protein = ESMProtein(
    coordinates=structure_prompt,
    sequence=sequence_prompt,
)

In [29]:
# Load the open ESM3 small checkpoint; fall back to CPU (slow) when no CUDA device is available
# instead of failing on a hard-coded "cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to(device)

In [31]:
# Encode the sequence+structure prompt with the model.
# NOTE(review): `protein_tensor` is not used by any cell visible here, and the generation loop
# builds its own prompt; presumably this is a check that the combined prompt encodes — confirm.
protein_tensor = model.encode(esm_protein)

In [None]:
# Alternative prompt definitions, kept for reference only (nothing here executes).
# NOTE(review): the generation loop constructs its own `esm_protein`, so un-commenting these
# lines does not change what the loop generates from.
# ## sequence as input
# esm_protein = ESMProtein(sequence=sequence_prompt)
# ## structure + sequence as input
# esm_protein = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)

In [None]:
# FASTA file that the generation loop writes each generated sequence to (via append=True).
# NOTE(review): record IDs depend only on temperature and generation index, so re-running the
# loop presumably appends duplicate IDs — clear or rename the file between runs.
# NOTE(review): the path says 6mfw while the prompt comes from gxps_ATC_hm_6mfy.pdb — confirm intended.
fasta_file = '../../Data/esm3_experiments/6mfw_exp/6mfw_esm3_1000.fasta'

In [None]:
N_GENERATIONS = 100
temperature = 0.5
run_structure = False  # if True, also fold generations 0 and 1 as a sanity check
use_structure_prompt = False  # if True, also condition sequence generation on `structure_prompt`


def residues_at(sequence, positions):
    """Return the characters of `sequence` at `positions`, in ascending position order.

    `positions` is converted to a set once, so each membership test is constant time.
    """
    wanted = set(positions)
    return ''.join(aa for i, aa in enumerate(sequence) if i in wanted)


# The prompt never changes inside the loop, so derive the step count once.
num_sequence_steps = sequence_prompt.count("_") // 2

# Use the same index constants the prompts were built from (gxps_atc), not another structure's.
print(f'T domain: {protein[T_gxps_atc].sequence}')
for idx in range(N_GENERATIONS):

    if run_structure and idx > 1:
        run_structure = False
        print('stopping structure prediction')

    gen_idx = f'gxps_ATC_esm3_temp_{temperature}_gen_{idx}'

    sequence_prediction_config = GenerationConfig(
        track="sequence",
        num_steps=num_sequence_steps,
        temperature=temperature,
    )
    esm_protein = ESMProtein(
        sequence=sequence_prompt,
        coordinates=structure_prompt if use_structure_prompt else None,
    )
    generated_protein = model.generate(esm_protein, sequence_prediction_config)
    generated_sequence = generated_protein.sequence

    if run_structure:
        ## generate structure from the generated sequence
        structure_prediction_config = GenerationConfig(
            track="structure",
            num_steps=len(generated_protein) // 8,
            temperature=temperature,
        )
        structure_prediction_prompt = ESMProtein(sequence=generated_sequence)
        structure_prediction = model.generate(structure_prediction_prompt, structure_prediction_config)

        assert generated_sequence == structure_prediction.sequence
        # structure_prediction.to_pdb(os.path.join(pdbfile_loc, gen_idx))

    print(f"T domain: {residues_at(generated_sequence, T_gxps_atc)}")

    # Every unmasked prompt position (the fixed A- and C-domain residues) must come back unchanged.
    # Checked against the prompt itself so the check cannot drift from how the prompt was built.
    assert len(generated_sequence) == len(sequence_prompt), f'{gen_idx}: sequence length changed'
    assert all(p == '_' or p == g for p, g in zip(sequence_prompt, generated_sequence)), \
        f'{gen_idx}: fixed prompt residues were altered'

    # Write each generation immediately so finished sequences survive an interrupted run.
    helper.create_fasta({gen_idx: generated_sequence}, fasta_file, append=True)