In [1]:
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset
import pytorch_lightning as pl
from pytorch_lightning.plugins import DDPPlugin

from embed_structure_model import trans_basic_block, trans_basic_block_Config
from tm_vec_utils import featurize_prottrans, embed_tm_vec

from transformers import T5EncoderModel, T5Tokenizer
import re
import gc

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns 

from Bio import SeqIO
import gzip

import faiss

In [2]:
#Pick the compute device up front: the first CUDA GPU when one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#Load the ProtTrans (ProtT5-XL, UniRef50) tokenizer and its encoder-only model by Hugging Face model id;
#do_lower_case=False asks the tokenizer not to lower-case the input sequences
prottrans_checkpoint = "Rostlab/prot_t5_xl_uniref50"
tokenizer = T5Tokenizer.from_pretrained(prottrans_checkpoint, do_lower_case=False)
model = T5EncoderModel.from_pretrained(prottrans_checkpoint)
gc.collect()

#Move the encoder onto `device` and switch it to evaluation mode for inference
model = model.to(device).eval()

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.16.layer.1.EncDecAttention.o.weight', 'decoder.block.5.layer.1.EncDecAttention.k.weight', 'decoder.block.11.layer.0.SelfAttention.o.weight', 'decoder.block.20.layer.1.layer_norm.weight', 'decoder.block.12.layer.0.SelfAttention.k.weight', 'decoder.block.12.layer.1.EncDecAttention.o.weight', 'decoder.block.18.layer.1.EncDecAttention.v.weight', 'decoder.block.5.layer.1.EncDecAttention.q.weight', 'decoder.block.21.layer.2.DenseReluDense.wi.weight', 'decoder.block.21.layer.1.EncDecAttention.q.weight', 'decoder.block.22.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.0.SelfAttention.q.weight', 'decoder.block.23.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.1.layer_norm.weight', 'decoder.block.2.layer.0.SelfAttention.k.weight', 'decoder.block.5.layer.2.layer_norm.weight', 'decoder.block.21.layer.0.SelfAttention.k.weight', 'decoder.block.23.lay

In [3]:
#TM-Vec model paths
#The checkpoint and its params.json must come from the same training run, so both paths are
#derived from one model directory; switching models means editing a single line.
TM_VEC_MODEL_DIR = "/mnt/home/thamamsy/ceph/deepblast/models/transformer_lr0.0001_dmodel1024_nlayer2_datasample_45_thresh_300_pairs_in_folds_included_23M_normal_tmax"
#Alternative model (NOTE: its checkpoint is not last.ckpt but
#checkpoints/epoch=3-step=1490999-val_loss=0.0272.ckpt, so update tm_vec_model_cpnt too):
#TM_VEC_MODEL_DIR = '/mnt/home/thamamsy/ceph/deepblast/models/transformer_lr0.000075_dmodel1024_nlayer2_all_pairs_tm_sample_95percent_141Mtest'
tm_vec_model_cpnt = f"{TM_VEC_MODEL_DIR}/checkpoints/last.ckpt"
tm_vec_model_config_path = f"{TM_VEC_MODEL_DIR}/params.json"

#Load the TM-Vec model: parse the config JSON first (kept in its own name, distinct from the path),
#then restore the weights with that config and move the model to the same device as ProtTrans in eval mode
tm_vec_model_config = trans_basic_block_Config.from_json(tm_vec_model_config_path)
model_deep = trans_basic_block.load_from_checkpoint(tm_vec_model_cpnt, config=tm_vec_model_config)
model_deep = model_deep.to(device).eval()

In [4]:
#Load some example sequences- in this case Bacteriocins
sequence_file = pd.read_csv("/mnt/home/thamamsy/ceph/deepblast/data/other_benchmarks/bagel_bacteriocins_class_1_unique.csv")
#Residue count per sequence (vectorised; a missing sequence yields NaN, which the filter below drops,
#instead of raising inside len())
sequence_file['length'] = sequence_file['Sequence'].str.len()
#Filter for sequences that meet some criteria- in this case sequences of at least 30 residues
#(the comparison is >= 30, so 30-residue sequences are kept)
sequence_file_longer_than_30 = sequence_file[sequence_file['length'] >= 30]

#Make a list of your sequences
flat_seqs = sequence_file_longer_than_30['Sequence'].tolist()
#Fail here with a clear message rather than later in np.concatenate on an empty embedding list
assert flat_seqs, "no sequences passed the length filter"

In [5]:
#Embed query sequences one at a time: ProtTrans features -> TM-Vec embedding.
#Each appended item is whatever embed_tm_vec returns for a single sequence
#(presumably a (1, d) array -- TODO confirm); the next cell concatenates them along axis 0.
embed_all_sequences = []
#Inference never needs gradients; no_grad stops autograd from recording a graph per sequence
#(harmless if the helpers already disable gradients internally).
with torch.no_grad():
    for sequence in flat_seqs:
        #the helpers take a list of sequences, so pass a one-element batch
        protrans_sequence = featurize_prottrans([sequence], model, tokenizer, device)
        embedded_sequence = embed_tm_vec(protrans_sequence, model_deep, device)
        embed_all_sequences.append(embedded_sequence)

In [6]:
#Stack the per-sequence embeddings row-wise (axis 0 is np.concatenate's default) into one query matrix;
#one row per sequence assuming embed_tm_vec returns single-row arrays -- TODO confirm
queries = np.concatenate(embed_all_sequences)

In [7]:
#Normalize queries to unit L2 norm. faiss.normalize_L2 rewrites `queries` in place (nothing is returned).
#Combined with the inner-product index built from equally normalised database rows, the search
#score becomes cosine similarity.
#NOTE(review): faiss expects a float32, C-contiguous array here -- confirm if the embedding dtype ever changes.
faiss.normalize_L2(queries)

In [8]:
#Load the pre-computed embedding database that the queries are searched against.
#Despite its name, `query_database` is the database being queried, not the queries themselves.
#It must have been embedded with the same model as the queries (e.g. a CATH model with the
#CATH database, a SWISS model with the SWISS database).
#`metadata_database` holds the identifiers, looked up later by the row positions the search returns.
query_database, metadata_database = (
    np.load(npy_path)
    for npy_path in (
        "/mnt/home/thamamsy/ceph/deepblast/data/embeddings_cath_s100_final.npy",
        "/mnt/home/thamamsy/ceph/deepblast/data/embeddings_cath_s100_ids.npy",
    )
)

In [9]:
#Build an exact (brute-force) inner-product index over the database embeddings.
#The rows are L2-normalised in place first (this does not change the shape), so the inner product
#equals cosine similarity, matching the normalised queries.
faiss.normalize_L2(query_database)
d = query_database.shape[1]
index = faiss.IndexFlatIP(d)
index.add(query_database)

In [10]:
#Return the k nearest neighbors of every query.
#D: similarity scores (inner products of unit vectors, i.e. cosine similarity), best first per row.
#I: the matching row positions in query_database / metadata_database.
#Both have one row per query and k columns.
k = 5
D, I = index.search(queries, k)

In [11]:
#These "TM scores" are TM-Vec predictions: cosine similarity between query and database embeddings,
#not TM-scores computed from a structural alignment.
print("TM scores for the nearest neighbors")
D

TM scores for the nearest neighbors


array([[0.5717027 , 0.56875336, 0.56745553, 0.5644287 , 0.55858564],
       [0.63027215, 0.6069225 , 0.6069196 , 0.6052871 , 0.6016288 ],
       [0.5485954 , 0.5453853 , 0.5337225 , 0.5337225 , 0.5247935 ],
       ...,
       [0.5482549 , 0.5482549 , 0.548123  , 0.54129726, 0.539116  ],
       [0.61546576, 0.6048607 , 0.5763045 , 0.5763045 , 0.5763045 ],
       [0.54362416, 0.53372455, 0.53309476, 0.53309476, 0.53239965]],
      dtype=float32)

In [12]:
#Get metadata for the nearest neighbors: near_ids[q, j] is the identifier of the j-th nearest
#database row for query q (row positions in I line up with rows of metadata_database).
#Fancy indexing with the whole I matrix does the per-row lookup in one step and returns a copy.
near_ids = metadata_database[I]
#faiss reports "no neighbour found" as index -1 (possible when the database holds fewer than k vectors).
#Blank those slots instead of letting -1 silently wrap around to the last identifier.
near_ids[I < 0] = ""

In [13]:
#Display the neighbour identifiers; same layout as D (one row per query, k columns, best match first).
print("Metadata for the nearest neighbors")
near_ids

Metadata for the nearest neighbors


array([['cath|4_3_0|1q16A02/28-40', 'cath|4_3_0|1pvc400/2-69',
        'cath|4_3_0|2fomA01/61-71', 'cath|4_3_0|1z7s400/2-69',
        'cath|4_3_0|2xzmW04/248-260'],
       ['cath|4_3_0|4ef8A02/54-71_197-222', 'cath|4_3_0|1lqlA01/4-29',
        'cath|4_3_0|5o7oC02/323-370', 'cath|4_3_0|2wadA01/50-56_280-302',
        'cath|4_3_0|3g2mB01/12-18_174-233'],
       ['cath|4_3_0|1q16A02/28-40', 'cath|4_3_0|2fomA01/61-71',
        'cath|4_3_0|4bpeW04/248-260', 'cath|4_3_0|2xzmW04/248-260',
        'cath|4_3_0|2z5bB02/81-93'],
       ...,
       ['cath|4_3_0|2xzmW04/248-260', 'cath|4_3_0|4bpeW04/248-260',
        'cath|4_3_0|2z5bB02/81-93', 'cath|4_3_0|3rf9A03/361-374',
        'cath|4_3_0|3u5cE04/246-261'],
       ['cath|4_3_0|1b8xA03/213-260', 'cath|4_3_0|1ev1400/2-69',
        'cath|4_3_0|3oixB02/55-72_196-219',
        'cath|4_3_0|3oixC02/55-72_196-219',
        'cath|4_3_0|3oixA02/55-72_196-219'],
       ['cath|4_3_0|1q16A02/28-40', 'cath|4_3_0|2virC01/43-56_273-309',
        'cath|4_3_0|2