In [1]:
import sys
sys.path.append('..')

In [2]:
import torch
import esm

import numpy as np
import matplotlib.pyplot as plt

In [3]:
from utils.base import BaseProtein

In [4]:
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()

In [5]:
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

In [6]:
protein = BaseProtein(file='../../Data/GxpS_ATC.pdb')
T = [i for i in range(538,608)] ## 539-608

In [7]:
protein.get_residues(T) ## T domain

'GEIEIALATIWRELLNVEQVGRHDSFFALGGHSLLAVRMIERLRRIGLGLSVQTLFQHPTLSVLAQSLVP'

In [8]:
masked_query = ''.join(['<mask>' if i in T else protein.sequence[i] for i in range(len(protein.sequence))])

In [9]:
masked_query

'PQQPVTAIDILSSSERELLLENWNATEEPYPTQVCVHQLFEQQIEKTPDAIAVIYENQTLSYAELNARANRLAHQLIALGVAPDQRVAICVTRSLARIIGLLAVLKAGGAYVPLDPAYPGERLAYMLTDATPVILMADNVGRAALSEDILATLTVLDPNTLLEQPDHNPQVSGLTPQHLAYVIYTSGSTGRPKGVMIEHRSVVNLTLTQITQFDVCATSRMLQFASFGFDASVWEIMMALSCGAMLVIPTETVRQDPQRLWRYLEEQAITHACLTPAMFHDGTDLPAIAIKPTLIFAGEAPSPALFQALCSRADLFNAYGPTEITVCATTWDCPADYTGGVIPIGSPVANKRLYLLDEHRQPVPLGTVGELYIGGVGVARGYLNRPELTAERFLNDPFSDETNARMYRAGDLARYLPDGNLVFVGRNDQQVKIRGFRIEPGEIEARLVEHSEVSEALVLALGDGQDKRLVAYVVALADDGLATKLREHLSDILPDYMIPAAFVRLDAFPLTPNGKLDRRSLPAPGEDAFARQAYQAPQ<mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask><mask>HREISVPDNGITADTTVLTPAMLPLIDLTQAEIDRIVEQVP

In [130]:
data = [
    ("protein1", masked_query)
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

In [131]:
# batch_tokens = batch_tokens.to(torch.device('cuda'))

In [132]:
batch_tokens.shape

torch.Size([1, 1105])

In [133]:
# ## Random Mask Approach
# with torch.no_grad():
#     fix = 1
#     inv = {v:k for k,v in alphabet.tok_to_idx.items()}
#     for i in range(len(T)):
#         print(f'run {i}')
#         results = model(batch_tokens, repr_layers=[33], return_contacts=True)
#         prob = torch.nn.functional.softmax(results['logits'], dim=-1)
#         masked_positions = torch.nonzero(batch_tokens == 32, as_tuple=False)
#         print(f'num maxed pos {masked_positions.size(0)}')
#         if masked_positions.size(0) > fix:
#             selected_positions = masked_positions[torch.randperm(masked_positions.size(0))[:fix]]
#             batch_tokens[0, selected_positions[:, 1]] = torch.multinomial(prob[:, selected_positions[:, 1], :][0], 1).flatten()
#         else:
#             selected_positions = masked_positions
#             batch_tokens[0, selected_positions[:, 1]] = torch.multinomial(prob[:, selected_positions[:, 1], :][0], 1).flatten()

#         print(f" T domain : {''.join([inv[i] for i in batch_tokens[0, 1:-1][T].numpy()])}")

#         if torch.nonzero(batch_tokens == 32, as_tuple=False).size(0) == 0:
#             break

In [134]:
## Min Entropy position
with torch.no_grad():
    fix = 2
    inv = {v:k for k,v in alphabet.tok_to_idx.items()}
    for i in range(len(T)):
        print(f'run {i}')
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
        prob = torch.nn.functional.softmax(results['logits'], dim=-1)
        masked_positions = torch.nonzero(batch_tokens == 32, as_tuple=False)
        print(f'num maxed pos {masked_positions.size(0)}')
        if masked_positions.size(0) > fix:
            masked_probs = prob[0 ,masked_positions[:, 1], :]
            entropies = -torch.sum(masked_probs * torch.log(masked_probs + 1e-10), dim=-1)
            indices = torch.argsort(entropies)[:fix]
            selected_positions = masked_positions[:, 1][indices]
            batch_tokens[0, selected_positions] = torch.multinomial(prob[:, selected_positions, :][0], 1).flatten()
        else:
            selected_positions = masked_positions[:, 1]
            batch_tokens[0, selected_positions] = torch.multinomial(prob[:, selected_positions, :][0], 1).flatten()

        print(f" T domain : {''.join([inv[i] if inv[i] != '<mask>' else '_' for i in batch_tokens[0, 1:-1][T].numpy()])}")

        if torch.nonzero(batch_tokens == 32, as_tuple=False).size(0) == 0:
            break

run 0
num maxed pos 70
 T domain : S__E__________________________________________________________________
run 1
num maxed pos 68
 T domain : S__E__LL______________________________________________________________
run 2
num maxed pos 66
 T domain : S__E__LL__I___L_______________________________________________________
run 3
num maxed pos 64
 T domain : S__E__LL__I___L___________________________________________________I__P
run 4
num maxed pos 62
 T domain : SE_E__LL__I__LL___________________________________________________I__P
run 5
num maxed pos 60
 T domain : SE_EH_LL__IQ_LL___________________________________________________I__P
run 6
num maxed pos 58
 T domain : SE_EH_LL_LIQ_LLR__________________________________________________I__P
run 7
num maxed pos 56
 T domain : SETEH_LL_LIQ_LLRR_________________________________________________I__P
run 8
num maxed pos 54
 T domain : SETEH_LLQLIQGLLRR_________________________________________________I__P
run 9
num maxed pos 52
 T domain : SETEHCLLQLI

In [None]:
protein.get_residues(T)

In [None]:
masked_positions

In [None]:
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]

In [None]:
results['logits'].shape, token_representations.shape

In [None]:
results['logits']

In [None]:
prob = torch.nn.functional.softmax(results['logits'], dim=-1)

In [None]:
prob[:, selected_positions[:, 1], :][0]

In [None]:
i = 596
plt.plot(prob[0, i, :])
plt.yscale('log')
torch.argmax(prob[0, i, :])

In [None]:
inv = {v:k for k,v in alphabet.tok_to_idx.items()}

In [None]:
inv

In [None]:
torch.argmax(prob[0, T, :], dim=-1).numpy()

In [None]:
''.join([inv[i] for i in torch.argmax(prob[:,1:-1,:][0, T, :], dim=-1).numpy()])

In [None]:
protein.get_residues(T)

In [None]:
print(T)

In [None]:
select = [608,609,700,701,702,703,704,705]
print(protein.get_residues(select))
print(''.join([inv[i] for i in torch.argmax(prob[:,1:-1,:][0, select, :], dim=-1).numpy()]))

In [None]:
prob[:,1:-1,:].shape

In [None]:
len(protein.sequence)