In [30]:
from pathlib import Path
import os
import sys
import gzip
import wget
from transformers import AutoTokenizer, AutoModel
import torch
import pandas as pd

curdir = Path(os.getcwd())
sys.path.append(str(curdir.parent.absolute()))

from src.models.ProTCL import ProTCL
from src.utils.models import generate_label_embeddings_from_text
from src.models.protein_encoders import ProteInfer

%load_ext autoreload
%autoreload 2


from src.utils.data import load_model_weights, read_pickle

# Define paths
model_path = '/home/ncorley/protein/ProteinFunctions/models/ProTCL/2023-09-25_22-13-48_ProTCL.pt'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Build the ProteInfer sequence encoder from its pretrained GO-model weights.
# NOTE(review): these architecture settings presumably have to match the ones
# the checkpoint was trained with — confirm before changing any of them.
proteinfer_weights_path = '/home/ncorley/protein/ProteinFunctions/models/proteinfer/GO_model_weights.pkl'
proteinfer_hparams = dict(
    num_labels=32102,        # presumably the size of ProteInfer's GO label vocabulary — confirm
    input_channels=20,       # presumably one channel per standard amino acid — confirm
    output_channels=1100,
    kernel_size=9,
    activation=torch.nn.ReLU,
    dilation_base=3,
    num_resnet_blocks=5,
    bottleneck_factor=0.5,
)
sequence_encoder = ProteInfer.from_pretrained(
    weights_path=proteinfer_weights_path,
    **proteinfer_hparams,
)

In [13]:
# Use the GPU when one is available; otherwise fall back to CPU instead of
# crashing on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# BioGPT tokenizer + encoder, later used to embed free-text label descriptions.
label_tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
label_encoder = AutoModel.from_pretrained("microsoft/biogpt")

# Assemble ProTCL from the two encoders, move it to DEVICE, then load the
# trained checkpoint into it.
# NOTE(review): if load_model_weights calls torch.load without map_location, a
# checkpoint saved on CUDA will still fail to load on a CPU-only machine — confirm.
model = ProTCL(
    label_encoder=label_encoder,
    sequence_encoder=sequence_encoder,
).to(DEVICE)
load_model_weights(model, model_path)


In [14]:
# Free-text description of the function we want to search for.
text = "Transmembrane protein found in prokaryotes involved in cellular signalling"

# Embedding is pure inference. Switch the encoder to eval mode so dropout is off,
# even though model.eval() only runs in a later cell. Disable autograd so no
# computation graph is built and kept alive alongside the result.
# NOTE(review): the meaning of the literal 300 depends on the helper's signature
# (batch size? max token length?) — confirm.
label_encoder.eval()
with torch.no_grad():
    label_embeddings = generate_label_embeddings_from_text(text, label_tokenizer, label_encoder, 300)[0]

# Add a leading dimension to the single label embedding (presumably the
# "number of labels" axis the model expects — confirm).
label_embeddings = label_embeddings.unsqueeze(0)

In [15]:
import torch
import numpy as np

# Precomputed frozen ProteInfer embeddings, keyed by sequence id.
# read_pickle deserialises arbitrary objects: only point this at trusted files.
sequence_embedding_path = '/home/ncorley/protein/ProteinFunctions/data/embeddings/proteinfer/frozen_proteinfer_sequence_embeddings.pkl'
sequence_embedding_dict = read_pickle(sequence_embedding_path)

# Split the dict into two aligned lists. keys() and values() iterate the same
# dict in the same order, so index k of both lists refers to the same sequence.
sequence_ids_list = list(sequence_embedding_dict.keys())

# torch.as_tensor reuses the existing buffer when an embedding is already a numpy
# array or tensor, whereas torch.tensor always copies (and warns on tensor input).
# With ~500k embeddings this avoids holding a second full copy in memory.
# NOTE(review): assumes the stored embeddings are numpy arrays or tensors — confirm.
sequence_embeddings_list = [torch.as_tensor(embedding) for embedding in sequence_embedding_dict.values()]


In [16]:
# Score every precomputed sequence embedding against the query label embedding.
# Sequences go through the model in fixed-size batches to bound GPU memory.
model.eval()

BATCH_SIZE = 2000
LOG_EVERY_N_BATCHES = 40  # print progress once every this many batches

# Run on whatever device the model lives on rather than hard-coding 'cuda'.
model_device = next(model.parameters()).device

similarity_scores = []  # (sequence_id, logit) pairs, in input order
label_embeddings = label_embeddings.to(model_device)
num_sequences = len(sequence_embeddings_list)

with torch.no_grad():
    for i in range(0, num_sequences, BATCH_SIZE):
        batch_sequence_embeddings = torch.stack(sequence_embeddings_list[i:i + BATCH_SIZE]).to(model_device)
        logits = model(sequence_embeddings=batch_sequence_embeddings, label_embeddings=label_embeddings)

        # Transfer the batch's logits to Python floats once. Storing per-row device
        # tensors would keep every batch output alive in GPU memory and force a
        # device sync for each element when they are later converted.
        batch_logits = logits.reshape(-1).tolist()
        batch_ids = sequence_ids_list[i:i + BATCH_SIZE]
        # With a single query label, each sequence must produce exactly one score.
        assert len(batch_logits) == len(batch_ids), (
            f"expected one logit per sequence, got {len(batch_logits)} for {len(batch_ids)} sequences"
        )
        similarity_scores.extend(zip(batch_ids, batch_logits))

        if (i // BATCH_SIZE) % LOG_EVERY_N_BATCHES == 0:
            print(f"Processed {i} out of {num_sequences} sequences")

# Sorting by logit happens in the next cell.


Processed 0 out of 522607 sequences
Processed 80000 out of 522607 sequences
Processed 160000 out of 522607 sequences
Processed 240000 out of 522607 sequences
Processed 320000 out of 522607 sequences
Processed 400000 out of 522607 sequences
Processed 480000 out of 522607 sequences


In [17]:
import numpy as np

print("Sorting similarity scores...")

# Pack the (sequence_id, logit) pairs into a structured array so numpy sorts them
# natively instead of comparing Python objects. Size the id field from the data:
# a fixed width such as 'U50' would silently truncate longer ids. The floor of 1
# avoids an unsized 'U0' dtype.
max_id_len = max([1] + [len(str(seq_id)) for seq_id, _ in similarity_scores])
dtype = [('sequence_id', f'U{max_id_len}'), ('logit', float)]
similarity_array = np.array(similarity_scores, dtype=dtype)

# Sort ascending by 'logit'. numpy breaks ties on the remaining field
# (sequence_id). The [::-1] then flips the result into descending order.
sorted_array = np.sort(similarity_array, order='logit')[::-1]

# Back to a list of (sequence_id, logit) tuples for the downstream cells.
sorted_scores = list(map(tuple, sorted_array))


Sorting similarity scores...


In [33]:
# The 30 best-scoring sequences (sorted_scores is already in descending logit order).
top_30 = sorted_scores[0:30]
top_30

[('Q7VIA0', 2.557553291320801),
 ('A8G311', 2.5293922424316406),
 ('Q2RTH2', 2.482395887374878),
 ('A3PB74', 2.2203664779663086),
 ('B3EMS7', 2.001722812652588),
 ('A2BPF4', 1.962784767150879),
 ('A7GWS4', 1.9538805484771729),
 ('Q5N2J3', 1.7185605764389038),
 ('Q31RR1', 1.7185605764389038),
 ('Q9RJ68', 1.5750106573104858),
 ('A7ZF22', 1.541426420211792),
 ('Q8NQE4', 1.4177640676498413),
 ('Q828H7', 1.4150567054748535),
 ('Q1CUI2', 1.4137609004974365),
 ('B2USE9', 1.371057391166687),
 ('O25088', 1.3691730499267578),
 ('Q9ZMB8', 1.3100190162658691),
 ('B6JKP6', 1.3015974760055542),
 ('A4TB40', 1.2742552757263184),
 ('Q17X55', 1.2611216306686401),
 ('Q98LC5', 1.2023166418075562),
 ('Q8FTE9', 1.1643602848052979),
 ('B1W317', 1.1350586414337158),
 ('C3MBT9', 1.0523419380187988),
 ('B4SE06', 0.9217094779014587),
 ('A8FMN2', 0.9044288992881775),
 ('A5CXC5', 0.8850680589675903),
 ('Q8KC14', 0.8686895966529846),
 ('Q8UEP9', 0.852597177028656),
 ('Q9PNB9', 0.7877622842788696)]

In [22]:
# Load the swissprot dataframe
df_2023 = read_pickle('/home/ncorley/protein/ProteinFunctions/data/swissprot/swissprot_2023.pkl')
df_2023.head()

Unnamed: 0,seq_id,sequence,go_ids,description,organism,organism_classification,organelle,cc,subcellular_location
0,Q6GZX4,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,[GO:0046782],RecName: Full=Putative transcription factor 001R;,Frog virus 3 (isolate Goorha) (FV-3).,"[Viruses, Varidnaviria, Bamfordvirae, Nucleocy...",,{'FUNCTION': 'Transcription activation. {ECO:0...,
1,Q6GZX3,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,"[GO:0033644, GO:0016020]",RecName: Full=Uncharacterized protein 002L;,Frog virus 3 (isolate Goorha) (FV-3).,"[Viruses, Varidnaviria, Bamfordvirae, Nucleocy...",,{'SUBCELLULAR LOCATION': 'Host membrane {ECO:0...,Host membrane {ECO:0000305}; Single-pass membr...
2,Q197F8,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,[],RecName: Full=Uncharacterized protein 002R;,Invertebrate iridescent virus 3 (IIV-3) (Mosqu...,"[Viruses, Varidnaviria, Bamfordvirae, Nucleocy...",,{},
3,Q197F7,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,[],RecName: Full=Uncharacterized protein 003L;,Invertebrate iridescent virus 3 (IIV-3) (Mosqu...,"[Viruses, Varidnaviria, Bamfordvirae, Nucleocy...",,{},
4,Q6GZX2,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,[],RecName: Full=Uncharacterized protein 3R; Flag...,Frog virus 3 (isolate Goorha) (FV-3).,"[Viruses, Varidnaviria, Bamfordvirae, Nucleocy...",,{},


In [34]:
# Table of the top-30 hits with their scores (top_30 is expected to be sorted
# by descending score).
df_top_30 = pd.DataFrame(top_30, columns=['seq_id', 'score'])

# Attach the Swiss-Prot annotations for each hit.
# - how='left' keeps every hit, in rank order, even when it has no Swiss-Prot
#   row; the default inner join would drop it silently.
# - validate='many_to_one' raises if df_2023 contains duplicate seq_ids, which
#   would otherwise silently duplicate hits in the result.
df_top_30 = df_top_30.merge(df_2023, on='seq_id', how='left', validate='many_to_one')


In [36]:
# Persist the annotated top hits to CSV, then preview them. The preview must be
# the cell's last expression: a bare `.head()` placed earlier in the cell is
# evaluated and discarded, so nothing would be displayed.
top_hits_csv_path = '/home/ncorley/protein/ProteinFunctions/data/swissprot/top_30.csv'
df_top_30.to_csv(top_hits_csv_path, index=False)
df_top_30.head()