In [1]:
import sys
import os
# Exported before the model libraries are imported in the next cell
# (presumably read by the Hugging Face tokenizers backend — confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
sys.path.append('../../esm') ## only needed for a local esm checkout; ignore if esm3 is installed as a package
# Make the parent directory importable (presumably where the DomainPrediction package lives — confirm).
sys.path.append('..')

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle

from esm.models.esm3 import ESM3
from esm.sdk.api import ESM3InferenceClient, ESMProtein, GenerationConfig, SamplingConfig, LogitsConfig, SamplingTrackConfig
from esm.utils.structure.protein_chain import ProteinChain

2024-10-16 11:02:00.566887: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-16 11:02:00.583236: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-16 11:02:00.583260: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-16 11:02:00.594028: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
from DomainPrediction.protein.base import BaseProtein
from DomainPrediction.eval import metrics
from DomainPrediction.utils import helper
from DomainPrediction.utils.constants import *

In [4]:
# Root folder for the round-2 experiment data, relative to this notebook.
data_path = '../../Data/round_2_exp'

In [5]:
# Load the PDB structure and extract the residues selected by T_gxps_atc
# (presumably the T-domain residue range provided by the wildcard constants import — confirm).
protein = BaseProtein('../../Data/gxps/gxps_ATC_hm_6mfy.pdb')
Tdomain = protein.get_residues(T_gxps_atc)

# Tdomain is treated as a plain amino-acid string from here on (it is sliced and concatenated).
# NOTE(review): the scan loops further down rebind sequence_prompt to masked strings.
sequence_prompt = Tdomain

In [6]:
# Load the pretrained "esm3_sm_open_v1" checkpoint and move it to the CUDA device.
# Every later cell (and calculate_likelihood) reads this module-level `model`.
model: ESM3InferenceClient = ESM3.from_pretrained("esm3_sm_open_v1").to("cuda")

In [7]:
from tqdm import tqdm

In [8]:
def calculate_likelihood(sequence):
    """Return the pseudo-log-likelihood of ``sequence`` under the global ``model``.

    Each position is replaced in turn by ``'_'`` (assumed to be the mask
    character understood by the sequence tokenizer — confirm), the model is
    queried for sequence logits, and the log-probability assigned to the true
    residue at the masked position is added to the running total. This costs
    one ``model.logits`` call per residue.

    Args:
        sequence: Amino-acid string. Every character must encode to exactly
            three tokens (one residue token between a leading and a trailing
            token), otherwise an ``AssertionError`` is raised.

    Returns:
        float: Sum over positions of ``log p(residue | all other residues)``.
        ``0.0`` for an empty sequence.
    """
    tokenizer = model.tokenizers.sequence

    # Accumulate in a Python float so the summation precision and the return
    # type do not depend on NumPy's scalar-promotion rules.
    log_prob = 0.0
    for pos, aa in enumerate(sequence):
        tokens = tokenizer.encode(aa)
        assert len(tokens) == 3
        token = tokens[1]
        assert tokenizer.decode(token) == aa

        sequence_prompt = sequence[:pos] + '_' + sequence[pos+1:]
        esm_protein = ESMProtein(sequence=sequence_prompt)
        protein_tensor = model.encode(esm_protein)
        res = model.logits(protein_tensor, LogitsConfig(
            sequence = True
        ))

        # Drop the first and last rows (non-residue tokens), select the masked
        # position, and upcast so the normalisation runs in float32 at least.
        logits = res.logits.sequence[0, 1:-1, :][pos, :].float().cpu()

        # log_softmax stays finite where log(softmax(x)) would underflow to
        # -inf for a very unlikely residue.
        log_prob += torch.log_softmax(logits, dim=0)[token].item()

    return log_prob


In [9]:
# Cache mapping a sequence string to its calculate_likelihood score, so a
# sequence that comes up again in the scan cells is not re-scored.
likelihood_db = {}

In [10]:
# Score the starting T-domain sequence; the scan cell below uses this value as its baseline.
likelihood_db[Tdomain] = calculate_likelihood(Tdomain)
print(f'WT T domain LL: {likelihood_db[Tdomain]}')

WT T domain LL: -97.99259356735274


In [11]:
# Round 1 of a greedy single-site scan over the starting T-domain sequence.
# For every position: mask it, take the model's top-scoring token there,
# substitute it into the sequence, and score the full variant with
# calculate_likelihood (cached in likelihood_db). The best-scoring variant seen
# so far is tracked in best_seq / max_LL.
max_LL = likelihood_db[Tdomain]
# Start from the parent sequence so best_seq is always bound, even when no
# substitution beats the baseline (otherwise later cells hit a NameError or
# silently pick up a stale value from an earlier run).
best_seq = Tdomain
print(f'WT T domain LL: {likelihood_db[Tdomain]}')
for pos in range(len(Tdomain)):
    aa_wt = Tdomain[pos]
    # encode() of a single residue puts the residue token at index 1; the
    # decode round-trip below checks that.
    token_wt = model.tokenizers.sequence.encode(aa_wt)[1]
    assert model.tokenizers.sequence.decode(token_wt) == aa_wt

    # '_' stands in for the residue being re-predicted.
    sequence_prompt = Tdomain[:pos] + '_' + Tdomain[pos+1:]
    esm_protein = ESMProtein(sequence=sequence_prompt)
    protein_tensor = model.encode(esm_protein)
    res = model.logits(protein_tensor, LogitsConfig(
        sequence = True
    ))

    # One logit row per prompt character plus one extra row at each end.
    assert res.logits.sequence.shape[1] == len(sequence_prompt) + 2

    logits = res.logits.sequence[0, 1:-1, :][pos, :].cpu()
    softmax_prob = torch.nn.functional.softmax(logits, dim=0)

    # softmax is order-preserving, so the argmax over the raw logits already
    # identifies the most probable token; a second argmax over the
    # probabilities adds nothing and could only disagree through rounding ties.
    token_mt = torch.argmax(logits)
    aa_mt = model.tokenizers.sequence.decode(token_mt)

    gen_seq = Tdomain[:pos] + aa_mt + Tdomain[pos+1:]

    # When the model agrees with the current residue, gen_seq equals Tdomain
    # and the cached baseline score is reused.
    if gen_seq not in likelihood_db:
        likelihood_db[gen_seq] = calculate_likelihood(gen_seq)

    ll = likelihood_db[gen_seq]

    print(pos, aa_wt, aa_mt, softmax_prob[token_wt], softmax_prob[token_mt], ll)

    if ll > max_LL:
        print(f'mutation at pos {pos} from {aa_wt} to {aa_mt} improved ll from {max_LL} to {ll}')
        max_LL = ll
        best_seq = gen_seq

WT T domain LL: -97.99259356735274
0 G S tensor(0.0756) tensor(0.3232) -97.48324622167274
mutation at pos 0 from G to S improved ll from -97.99259356735274 to -97.48324622167274
1 E P tensor(0.2288) tensor(0.3277) -97.85116550745443
2 I R tensor(0.0601) tensor(0.2433) -96.41527913091704
mutation at pos 2 from I to R improved ll from -97.48324622167274 to -96.41527913091704
3 E E tensor(0.9781) tensor(0.9781) -97.99259356735274
4 I E tensor(0.0192) tensor(0.1820) -96.87033583736047
5 A R tensor(0.1811) tensor(0.2254) -98.73617736296728
6 L L tensor(0.5861) tensor(0.5861) -97.99259356735274
7 A A tensor(0.6852) tensor(0.6852) -97.99259356735274
8 T E tensor(0.0581) tensor(0.1975) -97.12931999890134
9 I I tensor(0.4891) tensor(0.4891) -97.99259356735274
10 W W tensor(0.9569) tensor(0.9569) -97.99259356735274
11 R Q tensor(0.1927) tensor(0.2084) -98.2137359092012
12 E E tensor(0.6387) tensor(0.6387) -97.99259356735274
13 L L tensor(0.7949) tensor(0.7949) -97.99259356735274
14 L L tensor(0.

In [17]:
Tdomain

'GEIEIALATIWRELLNVEQVGRHDSFFALGGHSLLAVRMIERLRRIGLGLSVQTLFQHPTLSVLAQSLVP'

In [16]:
best_seq

'GEIEIALATIWRELLNVEQVGRHDSFFALGGHSLLAVRMIERARRIGLGLSVQTLFQHPTLSVLAQSLVP'

In [18]:
# Round 2 of the greedy single-site scan: restart from the best variant so far.
# NOTE: Tdomain is rebound here, so re-running this cell continues from whatever
# best_seq currently holds instead of repeating the same round.
Tdomain = best_seq
max_LL = likelihood_db[Tdomain]
print(f'prev best T domain LL: {likelihood_db[Tdomain]}')

for pos, aa_wt in enumerate(Tdomain):
    # Token id of the current residue; the decode round-trip guards the lookup.
    token_wt = model.tokenizers.sequence.encode(aa_wt)[1]
    assert model.tokenizers.sequence.decode(token_wt) == aa_wt

    # Mask this position and request per-position sequence logits.
    sequence_prompt = Tdomain[:pos] + '_' + Tdomain[pos+1:]
    esm_protein = ESMProtein(sequence=sequence_prompt)
    protein_tensor = model.encode(esm_protein)
    res = model.logits(protein_tensor, LogitsConfig(sequence=True))

    # The model returns two more rows than there are prompt characters.
    assert res.logits.sequence.shape[1] == len(sequence_prompt) + 2

    # Drop the first and last rows, then select the masked position.
    logits = res.logits.sequence[0, 1:-1, :][pos, :].cpu()
    softmax_prob = torch.nn.functional.softmax(logits, dim=0)

    # Top prediction taken from both the logits and the probabilities; the
    # asserts check that the two views agree.
    token_mt = torch.argmax(logits)
    token_mt_prob = torch.argmax(softmax_prob)
    assert token_mt == token_mt_prob

    aa_mt = model.tokenizers.sequence.decode(token_mt)
    aa_mt_prob = model.tokenizers.sequence.decode(token_mt_prob)
    assert aa_mt == aa_mt_prob

    # Candidate sequence with the model's preferred residue at this position.
    gen_seq = Tdomain[:pos] + aa_mt + Tdomain[pos+1:]

    # Score the candidate only when it has not been scored before.
    if gen_seq not in likelihood_db:
        likelihood_db[gen_seq] = calculate_likelihood(gen_seq)
    ll = likelihood_db[gen_seq]

    print(pos, aa_wt, aa_mt, softmax_prob[token_wt], softmax_prob[token_mt], ll)

    # Keep the highest-scoring candidate seen so far in this round.
    if ll > max_LL:
        print(f'mutation at pos {pos} from {aa_wt} to {aa_mt} improved ll from {max_LL} to {ll}')
        max_LL, best_seq = ll, gen_seq

prev best T domain LL: -94.23718828707933
0 G S tensor(0.0759) tensor(0.3298) -93.02658353932202
mutation at pos 0 from G to S improved ll from -94.23718828707933 to -93.02658353932202
1 E P tensor(0.2320) tensor(0.3323) -93.83913527755067
2 I R tensor(0.0376) tensor(0.2424) -91.05985924089327
mutation at pos 2 from I to R improved ll from -93.02658353932202 to -91.05985924089327
3 E E tensor(0.9822) tensor(0.9822) -94.23718828707933
4 I E tensor(0.0184) tensor(0.1843) -92.84958053473383
5 A A tensor(0.2912) tensor(0.2912) -94.23718828707933
6 L L tensor(0.9472) tensor(0.9472) -94.23718828707933
7 A A tensor(0.7608) tensor(0.7608) -94.23718828707933
8 T E tensor(0.0527) tensor(0.2003) -93.39326152484864
9 I I tensor(0.7037) tensor(0.7037) -94.23718828707933
10 W W tensor(0.9525) tensor(0.9525) -94.23718828707933
11 R Q tensor(0.1959) tensor(0.2022) -94.38500431925058
12 E E tensor(0.6471) tensor(0.6471) -94.23718828707933
13 L L tensor(0.7582) tensor(0.7582) -94.23718828707933
14 L L t

In [19]:
best_seq

'GEIEIALATIWRELLNVEQVGRHDSFFALGGHSLLAVRMIERARRIGLGLSVQTLFQHPTLSALAQSLVP'