### In this notebook we just want to make sure that we are sampling properly from ESM

In [2]:
import torch
import numpy as np
import scipy
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
from Bio import SeqIO

import sys
sys.path.insert(1, "/home/luchinoprince/Dropbox/Old_OneDrive/Phd/Second_year/research/Feinauer/esm/")
import esm
#from esm.inverse_folding import util
import esm.pretrained as pretrained

## These imports are here to dig deeper into the sampling perplexity
from esm.inverse_folding.features import DihedralFeatures
from esm.inverse_folding.gvp_encoder import GVPEncoder
from esm.inverse_folding.gvp_utils import unflatten_graph
from esm.inverse_folding.gvp_transformer_encoder import GVPTransformerEncoder
from esm.inverse_folding.transformer_decoder import TransformerDecoder
from esm.inverse_folding.util import rotate, CoordBatchConverter 


from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.nn.functional import one_hot

from collections import defaultdict

import matplotlib.pyplot as plt



In [3]:
!wget https://files.rcsb.org/download/5YH2.cif -P .    # save this to the data folder in colab

--2023-05-23 17:10:22--  https://files.rcsb.org/download/5YH2.cif
Resolving files.rcsb.org (files.rcsb.org)... 128.6.158.49
Connecting to files.rcsb.org (files.rcsb.org)|128.6.158.49|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/octet-stream]
Saving to: ‘./5YH2.cif’

5YH2.cif                [    <=>             ]   1,39M  1,99MB/s    in 0,7s    

2023-05-23 17:10:23 (1,99 MB/s) - ‘./5YH2.cif’ saved [1456779]



In [5]:
# Build the pretrained ESM-IF1 inverse-folding model and the alphabet returned with it.
# `pretrained` is the alias of `esm.pretrained` set up in the import cell.
model, alphabet = pretrained.esm_if1_gvp4_t16_142M_UR50()

# Put the model in evaluation mode before sampling
# (presumably a torch.nn.Module, so eval() disables training-only layers — TODO confirm).
model = model.eval()



In [16]:
# Structure file and chain to analyse (.pdb format is also acceptable).
fpath = '5YH2.cif'
chain_id = 'C'

# Read the selected chain, then extract its coordinates and native sequence.
structure = esm.inverse_folding.util.load_structure(fpath, chain_id)
coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure(structure)

# Print the header, the native sequence and its length, one item per line.
print('Native sequence:', native_seq, len(native_seq), sep='\n')

Found 4 chains: ['A' 'C' 'B' 'D'] 

Loaded chain C

Native sequence:
SVLQSLFEHPLYRTVLPDLTEEDTLFNLNAEIRLYPKAASESYPNWLRFHIGINRYELYSRHNPVIAALLRDLLSQKISSVGMKSGGTQLKLIMSFQNYGQALFKPMKQTREQETPPDFFYFSDFERHNAEIAAFHLDRILDFRRVPPVAGRLVNMTREIRDVTRDKKLWRTFFVSPANNICFYGECSYYCSTEHALCGKPDQIEGSLAAFLPDLALAKRKTWRNPWRRSYHKRKKAEWEVDPDYCDEVKQTPPYDRGTRLLDIMDMTIFDFLMGNMDRHHYETFEKFGNDTFIIHLDNGRGFGKHSHDEMSILVPLTQCCRVKRSTYLRLQLLAKEEYKLSSLMEESLLQDRLVPVLIKPHLEALDRRLRLVLKVLSDCVEKDGFSAVVENDLD
395


In [6]:
# numpy (np) and torch are already imported in the import cell at the top of the notebook.

# Seed torch's RNG so this stochastic sample is the same on every re-run of the cell
# (assumes model.sample draws from torch's global RNG — TODO confirm).
torch.manual_seed(0)

# Draw one sequence for the loaded coordinates at temperature 1.
# Gradients are not needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    sampled_seq = model.sample(coords, temperature=1)
print('Sampled sequence:', sampled_seq)

# zip() silently truncates to the shorter input, which would misreport recovery,
# so require equal lengths before comparing position by position.
assert len(sampled_seq) == len(native_seq), (
    f'length mismatch: sampled {len(sampled_seq)} vs native {len(native_seq)}'
)

# Fraction of positions where the sampled residue equals the native residue.
recovery = np.mean([a == b for a, b in zip(native_seq, sampled_seq)])
print('Sequence recovery:', recovery)

Sampled sequence: AGLDQWRSTLLTRTSDPLLSKEDDLFDLSEVKKMVPRTMPDQDASWLKFHLEIKKFRMYRRNSPTVSKLREELKRRKVMRVEQKTGGRTLTLRFKFEDFGEAWFKPRMARYAEETPIAFFNWERRADSRAVVAAYKLDKLMDLNRTPPCSARRFDMISEIRDVTDDEELRSTFFIRPTRQTAFYGKSEFLSSQEHAITGTPDVMFGAVTAMLPDLSIATRDEWQCPWRESDSKDVLSAWKIDPNYFNMIKLMPHLATGAALLNQADAWIFDHLMGNCDRHHFITFERFGKNTSFIIFDNGSGFGRSYHLNMDILEPLLQGCSMREVLYKRLVQLSQTRYSLERLLKDILGKDSGYPVLHEPFLKQLDVRLAEAVQVVKDCVKKVGEKKCLLKDTL
Sequence recovery: 0.389873417721519


In [8]:
# Seed torch's RNG so re-running the cell reproduces the same sequence
# (assumes model.sample draws from torch's global RNG — TODO confirm).
torch.manual_seed(0)

# Sample at a very low temperature (1e-6).
# Gradients are not needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    sampled_seq = model.sample(coords, temperature=1e-6)
print('Sampled sequence:', sampled_seq)

# zip() silently truncates to the shorter input, which would misreport recovery,
# so require equal lengths before comparing position by position.
assert len(sampled_seq) == len(native_seq), (
    f'length mismatch: sampled {len(sampled_seq)} vs native {len(native_seq)}'
)

# Fraction of positions where the sampled residue equals the native residue.
recovery = np.mean([a == b for a, b in zip(native_seq, sampled_seq)])
print('Sequence recovery:', recovery)

Sampled sequence: ERLRAWWASPLTQLPDPGLSEEDLLFDPEELLALLPEEEEEELPAWLRFWTGIRRRRLYERESPDVEELLRRLRTARVRRVGQKSGGRSLVLRFEFEDLGSAGFKPRVAELDEETPPEWGFWEVLQRARAVVAAYRLDRLLDLRQVPPAAGRRLDLVTELRDVTDDEELRSTFFVTPEEELCFYGRCEFRCDREHALCGRPDVVEGALVAELPDERIAPRGVYLNPWAHARERDVEALWEVDPDYCEYVRRLPPFREGRLLLELANAYVFDFLMGNADRHTFSTFERFGLDTFLLLLDNGFGFGRADYLDERILRPLEQCCLLSERLYRRLLALSEEEFSLEELMEEELGRDELWPVLARPFLRQLDRRLRRVLEVLEECEEECGREEVLVREEG
Sequence recovery: 0.4582278481012658
