In [10]:
import pickle, os
from tqdm.notebook import tqdm
import pandas as pd

# Data

In [3]:
# Location of the pickled features. The recorded output shows a dict keyed by
# PDB id whose values are DataFrames with "Residues" and "Label" column groups.
featuresf = "4.Features/features.pkl"
with open(featuresf, mode="rb") as f:
    featuresd = pickle.load(f)

featuresd

{'6ta3':     Residues                                                          \
          pdb label_entity_id label_asym_id label_seq_id auth_asym_id   
 0       6ta3               2             B            1            A   
 1       6ta3               2             B            2            A   
 2       6ta3               2             B            3            A   
 3       6ta3               2             B            4            A   
 4       6ta3               2             B            5            A   
 ..       ...             ...           ...          ...          ...   
 767     6ta3               3             C          387            K   
 768     6ta3               3             C          388            K   
 769     6ta3               3             C          389            K   
 770     6ta3               3             C          390            K   
 771     6ta3               3             C          391            K   
 
                                   Label 

In [4]:
# Iterating a dict yields its keys, so this is the list of PDB ids that have features.
representatives = list(featuresd)

len(representatives), representatives

(289,
 ['6ta3',
  '5kwj',
  '6t4k',
  '5mo6',
  '3o0n',
  '1tjp',
  '4bqh',
  '1og0',
  '5jv1',
  '3auy',
  '5dl1',
  '1qqb',
  '6hgj',
  '4aau',
  '7juv',
  '6nzf',
  '5uyy',
  '5xez',
  '2b08',
  '2c2b',
  '3i0s',
  '4dpq',
  '2d5x',
  '7amj',
  '2qmx',
  '3e5u',
  '4bo1',
  '6mo1',
  '2him',
  '5aa4',
  '5h08',
  '3kq7',
  '2qei',
  '8jii',
  '5y66',
  '6pgr',
  '2ha2',
  '5dzh',
  '7anw',
  '6e5g',
  '4pcu',
  '6bky',
  '6gfm',
  '4ejl',
  '5f2e',
  '6r4d',
  '7ehm',
  '6o0h',
  '4a79',
  '3dc2',
  '6ueg',
  '3kh5',
  '8hnn',
  '5ngz',
  '7mew',
  '5m1a',
  '3smq',
  '5hue',
  '4lrl',
  '5urj',
  '6huj',
  '1hkb',
  '2ldb',
  '2cdq',
  '5lvx',
  '2dz9',
  '6bvn',
  '1jlr',
  '5fii',
  '6sbv',
  '4bxc',
  '5ifu',
  '7ljc',
  '2jfz',
  '6o0n',
  '5swo',
  '6w6d',
  '4dn0',
  '6dhm',
  '5on7',
  '5hot',
  '3kjn',
  '6shq',
  '4coj',
  '6qjw',
  '5tkb',
  '5mro',
  '6i5g',
  '2bu6',
  '4zsk',
  '5x33',
  '5m0s',
  '6f7i',
  '6usa',
  '2vk1',
  '5kcv',
  '3f6g',
  '4avc',
  '3cmu',
  '5

In [6]:
seqsf = "2.Clustering/seqs.pkl"

# Hash-based lookup: the membership test below runs once per pickled sequence,
# so testing against a set avoids rescanning the whole list every time.
representative_ids = set(representatives)

# The pickle maps (pdb, entity_id, asym_id) -> sequence. Only sequences of the
# representative structures are kept, re-keyed by (pdb, entity_id). Because the
# asym id is dropped, entries sharing a (pdb, entity_id) pair collapse to a
# single item and the last one read wins.
with open(seqsf, "rb") as f:
    seqs = {
        (pdb, entity_id): seq
        for (pdb, entity_id, _asym_id), seq in pickle.load(f).items()
        if pdb in representative_ids
    }

seqs

{('5kwj',
  '1'): 'MFSRPGLPVEYLQVPSASMGRDIKVQFQGGGPHAVYLLDGLRAQDDYNGWDINTPAFEEYYQSGLSVIMPVGGQSSFYTDWYQPSQSNGQNYTYKWETFLTREMPAWLQANKGVSPTGNAAVGLSMSGGSALILAAYYPQQFPYAASLSGFLNPSEGWWPTLIGLAMNDSGGYNANSMWGPSSDPAWKRNDPMVQIPRLVANNTRIWVYCGNGTPSDLGGDNIPAKFLEGLTLRTNQTFRDTYAADGGRNGVFNFPPNGTHSWPYWNEQLVAMKADIQHVLNGATPPAAPAAPAALEHHHHHH',
 ('6t4k',
  '1'): 'GSSHHHHHHSSGLVPRGSHMASLTEIEHLVQSVCKSYRETCQLRLEDLLRQRSNIFSREEVTGYQRKSMWEMWERCAHHLTEAIQYVVEFAKRLSGFMELCQNDQIVLLKAGAMEVVLVRMCRAYNADNRTVFFEGKYGGMELFRALGCSELISSIFDFSHSLSALHFSEDEIALYTALVLINAHRPGLQEKRKVEQLQYNLELAFHHHLHKTHRQSILAKLPPKGKLRSLCSQHVERLQIFQHLHPIVVQAAFPPLYKELFS',
 ('5mo6',
  '1'): 'GSMDIEFDDDADDDGSGSGSGSGSSGPVPSRARVYTDVNTHRPSEYWDYESHVVEWGNQDDYQLVRKLGRGKYSEVFEAINITNNEKVVVKILKPVKKKKIKREIKILENLRGGPNIITLADIVKDPVSRTPALVFEHVNNTDFKQLYQTLTDYDIRFYMYEILKALDYCHSMGIMHRDVKPHNVMIDHEHRKLRLIDWGLAEFYHPGQEYNVRVASRYFKGPELLVDYQMYDYSLDMWSLGCMLASMIFRKEPFFHGHDNYDQLVRIAKVLGTEDLYDYIDKYNIELDPRFNDILGRHSRKRWERFVHSENQHLVSPEALDFLDKLLRYDHQSRLTAREAMEHPYFYTVVK',
 ('3o0n',
  '1')

# Embeddings

In [9]:
# Directory under which the embedding pickles of this notebook are cached;
# created on first run, left untouched if it already exists.
embeddings_path = "5.Embeddings"
os.makedirs(embeddings_path, exist_ok=True)

In [11]:
seq_embsf = f"{embeddings_path}/seq_embeddings.pkl"

if os.path.isfile(seq_embsf):
    # Cache hit: reuse the embeddings computed by an earlier run.
    with open(seq_embsf, "rb") as f:
        seq_embsd = pickle.load(f)
else:
    # Deferred import: the embedding helper is only needed on a cache miss.
    from utils.embeddings import get_embs

    # Nested mapping: pdb -> entity_id -> {"seq": sequence, "embs": get_embs(sequence)}
    seq_embsd = {}
    for (pdb, entity_id), seq in tqdm(seqs.items(), smoothing=0):
        seq_embsd.setdefault(pdb, {})[entity_id] = {
            "seq": seq,
            "embs": get_embs(seq),
        }

    # Write to a sibling temp file and rename it into place, so an interrupted
    # dump never leaves a truncated pickle at the path the cache check looks for.
    seq_embs_tmpf = f"{seq_embsf}.tmp"
    with open(seq_embs_tmpf, "wb") as f:
        pickle.dump(seq_embsd, f)
    os.replace(seq_embs_tmpf, seq_embsf)

len(seq_embsd), seq_embsd

Some weights of EsmModel were not initialized from the model checkpoint at /data/fnerin/huggingface/hub/models--facebook--esm2_t33_650M_UR50D/snapshots/08e4846e537177426273712802403f7ba8261b6c and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/309 [00:00<?, ?it/s]

(289,
 {'5kwj': {'1': {'seq': 'MFSRPGLPVEYLQVPSASMGRDIKVQFQGGGPHAVYLLDGLRAQDDYNGWDINTPAFEEYYQSGLSVIMPVGGQSSFYTDWYQPSQSNGQNYTYKWETFLTREMPAWLQANKGVSPTGNAAVGLSMSGGSALILAAYYPQQFPYAASLSGFLNPSEGWWPTLIGLAMNDSGGYNANSMWGPSSDPAWKRNDPMVQIPRLVANNTRIWVYCGNGTPSDLGGDNIPAKFLEGLTLRTNQTFRDTYAADGGRNGVFNFPPNGTHSWPYWNEQLVAMKADIQHVLNGATPPAAPAAPAALEHHHHHH',
    'embs': array([[ 0.18417343,  0.04812526, -0.1441038 , ..., -0.01353498,
             0.18469599, -0.00558302],
           [ 0.10449459,  0.01706475, -0.01919748, ..., -0.1006778 ,
             0.2963119 ,  0.03123129],
           [ 0.02207799,  0.20550573,  0.03463919, ..., -0.04366823,
             0.09699859,  0.12048241],
           ...,
           [ 0.10153432,  0.14793044, -0.05333032, ..., -0.32693347,
             0.25719202, -0.20861787],
           [ 0.13853206,  0.12871557, -0.07205319, ..., -0.22483717,
             0.18781076, -0.2513377 ],
           [ 0.12047419,  0.05020089, -0.16103242, ..., -0.2516187 ,
             0.26720938, -0.25

In [12]:
embsf = f"{embeddings_path}/embeddings.pkl"

if os.path.isfile(embsf):
    # Cache hit: load the previously merged per-residue tables.
    with open(embsf, "rb") as f:
        embsd = pickle.load(f)
else:
    # Column pair on which residue rows and embedding rows are matched.
    join_cols = [("Residues", "label_entity_id"), ("Residues", "label_seq_id")]

    embsd = {}
    for pdb, embs in tqdm(seq_embsd.items(), smoothing=0):
        residues = featuresd[pdb][["Residues", "Label"]]

        # One frame per entity: the entity id repeated for every sequence
        # position, the 1-based position as a string, and the embedding matrix
        # under an "Embeddings" column group.
        # NOTE(review): assumes the "embs" array has one row per sequence position - confirm.
        entity_frames = [
            pd.concat(
                (
                    pd.Series([ent_id] * len(ent["seq"]), name="label_entity_id"),
                    pd.Series(range(1, len(ent["seq"]) + 1), name="label_seq_id", dtype=str),
                    pd.DataFrame(ent["embs"]),
                ),
                axis=1,
                keys=["Residues", "Residues", "Embeddings"],
            )
            for ent_id, ent in embs.items()
        ]

        # merge() defaults to an inner join, so residues without a matching
        # (entity, position) pair are not present in the result.
        embsd[pdb] = residues.merge(pd.concat(entity_frames), on=join_cols)

    with open(embsf, "wb") as f:
        pickle.dump(embsd, f)

len(embsd), embsd

  0%|          | 0/289 [00:00<?, ?it/s]

(289,
 {'5kwj':     Residues                                                          \
           pdb label_entity_id label_asym_id label_seq_id auth_asym_id   
  0       5kwj               1             B            1            B   
  1       5kwj               1             B            2            B   
  2       5kwj               1             B            3            B   
  3       5kwj               1             B            4            B   
  4       5kwj               1             B            5            B   
  ..       ...             ...           ...          ...          ...   
  278     5kwj               1             B          279            B   
  279     5kwj               1             B          280            B   
  280     5kwj               1             B          281            B   
  281     5kwj               1             B          282            B   
  282     5kwj               1             B          283            B   
  
                     

## Per-pocket

In [13]:
from utils.pocket_utils import get_mean_pocket_em

In [14]:
# Load the pocket table and keep only pockets of structures that have features.
# The filter uses Series.isin instead of interpolating the repr of the id list
# into a query string, which had to be re-parsed as an expression and would
# break on any id containing a quote character.
pockets = (
    pd.read_pickle("3.Pockets/old_fpocket/pockets.pkl")[
        ["pdb", "pocket", "label"]
    ]
    .loc[lambda df: df["pdb"].isin(list(featuresd))]
)
pockets

Unnamed: 0,pdb,pocket,label
0,6ta3,pocket16,0
1,6ta3,pocket18,0
2,6ta3,pocket11,0
3,6ta3,pocket15,0
4,6ta3,pocket7,0
...,...,...,...
8136,3d2p,pocket5,0
8137,3d2p,pocket8,0
8138,3d2p,pocket3,0
8139,3d2p,pocket12,0


In [15]:
# Regroup the flat columns under two top-level keys so they line up with the
# two-level column layout used for the embedding columns:
#   ("Pockets", "pdb"), ("Pockets", "pocket"), ("Label", "label")
# The guard keeps the cell re-runnable: once `pockets` already has two-level
# columns, the flat lookups below would raise KeyError.
if not isinstance(pockets.columns, pd.MultiIndex):
    pockets = pd.concat(
        (
            pockets[["pdb", "pocket"]],
            pockets["label"],
        ),
        axis=1,
        keys=["Pockets", "Label"],
    )

pockets

Unnamed: 0_level_0,Pockets,Pockets,Label
Unnamed: 0_level_1,pdb,pocket,label
0,6ta3,pocket16,0
1,6ta3,pocket18,0
2,6ta3,pocket11,0
3,6ta3,pocket15,0
4,6ta3,pocket7,0
...,...,...,...
8136,3d2p,pocket5,0
8137,3d2p,pocket8,0
8138,3d2p,pocket3,0
8139,3d2p,pocket12,0


In [17]:
# Append, column-wise, whatever the helper returns for each pocket row to the pocket table.
# For every row the helper receives the pdb id, the pocket name, that structure's
# table from `embsd`, and the directory holding the pocket files.
# NOTE(review): `get_mean_pocket_features` is not imported in any cell visible here
# (the import cell of this section brings in `get_mean_pocket_em`), so this cell
# presumably relied on a name left in the kernel by another cell - confirm and
# import the helper explicitly.
# NOTE(review): presumably the helper averages the `pdb_features` rows of the
# pocket's residues (the recorded output has 1280 "Embeddings" columns) - verify
# against utils.pocket_utils.
pockets_embeddings = pd.concat(
    (
        pockets,
        pockets.apply(
            lambda row: get_mean_pocket_features(
                row[("Pockets", "pdb")],
                row[("Pockets", "pocket")],
                pdb_features = embsd[row[("Pockets", "pdb")]],
                pockets_path = "3.Pockets/old_fpocket/pockets" # f"{pockets_path}/{pdb}/{pdb}_out/pockets/{pocket}_atm.cif"
            ), 
            axis=1 
        )
    ),
    axis=1,
)

In [18]:
# Display the per-pocket table: pocket identifiers, label, and the appended embedding columns.
pockets_embeddings

Unnamed: 0_level_0,Pockets,Pockets,Label,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings,Embeddings
Unnamed: 0_level_1,pdb,pocket,label,0,1,2,3,4,5,6,...,1270,1271,1272,1273,1274,1275,1276,1277,1278,1279
0,6ta3,pocket16,0,-0.094970,-0.116930,0.040117,-0.058876,-0.174019,-0.282384,-0.125708,...,0.163619,-0.133219,-0.005906,-0.028193,-0.035316,-0.044516,0.102737,-0.064026,-0.034069,0.033591
1,6ta3,pocket18,0,0.020169,-0.078439,0.061870,0.002930,-0.074058,-0.130024,0.047107,...,-0.011803,-0.083614,0.026626,0.024293,-0.019584,-0.023984,0.138034,-0.198573,0.043331,0.090689
2,6ta3,pocket11,0,0.007830,-0.014264,-0.099724,0.063862,-0.098108,-0.089830,0.097989,...,0.021253,-0.044613,-0.084751,-0.017245,0.060706,-0.143717,-0.012969,0.001674,-0.008403,0.172018
3,6ta3,pocket15,0,-0.119789,0.135076,-0.017253,0.089593,-0.042903,-0.116645,-0.001007,...,-0.004941,-0.076110,-0.009547,0.089223,-0.056367,-0.224084,-0.022458,0.010889,0.020006,0.102865
4,6ta3,pocket7,0,-0.022467,0.067619,-0.017286,0.028015,-0.088452,-0.100998,0.031924,...,0.029544,-0.008624,-0.026880,0.041739,-0.005503,-0.125429,0.008519,0.015348,-0.030216,0.105604
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8136,3d2p,pocket5,0,-0.016085,0.121749,0.129790,-0.018972,-0.059095,0.023959,-0.047728,...,0.057683,0.014217,-0.089059,0.065058,0.054057,0.080963,-0.065427,-0.157244,-0.102713,-0.123586
8137,3d2p,pocket8,0,0.031225,-0.002326,0.070191,-0.046042,-0.103540,-0.136247,0.075616,...,-0.034458,-0.093629,0.003621,-0.007728,-0.058156,0.115197,0.057976,-0.006428,-0.015101,-0.075469
8138,3d2p,pocket3,0,0.055994,-0.103168,0.083710,-0.076272,-0.240857,-0.084090,0.107200,...,0.097598,-0.048996,0.007177,0.045579,-0.029063,0.036612,0.033029,-0.094900,-0.026578,-0.138080
8139,3d2p,pocket12,0,-0.038334,-0.179599,0.115259,-0.068121,-0.260114,-0.112897,0.090455,...,0.016604,-0.023116,-0.037841,-0.069211,-0.014965,0.095894,-0.008730,-0.084798,-0.094889,-0.016773


In [19]:
pockets_embeddingsf = f"{embeddings_path}/pockets_embeddings.pkl"

if os.path.isfile(pockets_embeddingsf):
    # Cache hit: load the per-structure pocket tables written by an earlier run.
    with open(pockets_embeddingsf, "rb") as f:
        pockets_embeddingsd = pickle.load(f)
else:
    # Split the pocket table into one sub-table per structure; sort=False keeps
    # the groups in order of first appearance.
    pockets_embeddingsd = {
        pdb_id: pocket_rows
        for pdb_id, pocket_rows in pockets_embeddings.groupby(("Pockets", "pdb"), sort=False)
    }
    with open(pockets_embeddingsf, "wb") as f:
        pickle.dump(pockets_embeddingsd, f)

len(pockets_embeddingsd), pockets_embeddingsd

(289,
 {'6ta3':    Pockets           Label Embeddings                                          \
         pdb    pocket label          0         1         2         3         4   
  0     6ta3  pocket16     0  -0.094970 -0.116930  0.040117 -0.058876 -0.174019   
  1     6ta3  pocket18     0   0.020169 -0.078439  0.061870  0.002930 -0.074058   
  2     6ta3  pocket11     0   0.007830 -0.014264 -0.099724  0.063862 -0.098108   
  3     6ta3  pocket15     0  -0.119789  0.135076 -0.017253  0.089593 -0.042903   
  4     6ta3   pocket7     0  -0.022467  0.067619 -0.017286  0.028015 -0.088452   
  5     6ta3  pocket22     0   0.031462 -0.130496 -0.006135  0.123543 -0.165853   
  6     6ta3  pocket13     0   0.004531 -0.002420 -0.084282 -0.092910 -0.095742   
  7     6ta3  pocket14     0  -0.060161  0.063076 -0.102772  0.040867 -0.174732   
  8     6ta3   pocket1     0  -0.004252  0.047088 -0.040080  0.002065 -0.052970   
  9     6ta3   pocket2     1   0.001850 -0.126576 -0.058374  0.070443 -0.