In [33]:
# Import libraries
import os
import re
import numpy as np
import pathlib
import torch
from esm import FastaBatchedDataset, pretrained
import xgboost as xgb

In [34]:
#extraction of embeddings for protein sequences
def extract_embeddings(model_name, fasta_file, output_dir, tokens_per_batch=4096, seq_length=8000, repr_layers=(36,)):
    """Embed every sequence of a FASTA file and save one mean-pooled ``.pt`` file per entry.

    Parameters
    ----------
    model_name : str
        Forwarded unchanged to ``pretrained.load_model_and_alphabet``.
    fasta_file : str or path-like
        FASTA file read with ``FastaBatchedDataset.from_file``.
    output_dir : str or pathlib.Path
        Directory that receives the output files; created (with parents) if missing.
    tokens_per_batch : int
        Token budget handed to ``dataset.get_batch_indices``.
    seq_length : int
        Passed to ``alphabet.get_batch_converter`` and used as the upper bound on
        the number of positions averaged per sequence.
    repr_layers : iterable of int
        Layers whose representations are requested from the model. When the model
        exposes ``num_layers``, negative values count back from the last layer and
        out-of-range values are rejected.

    Returns
    -------
    list of pathlib.Path
        One path per saved file, in processing order (batches are formed by
        ``get_batch_indices``, so this need not match the FASTA order). Each file
        holds ``{"entry_id": str, "mean_representations": {layer: tensor}}``.

    Raises
    ------
    ValueError
        If a requested layer lies outside the range the model reports.
    """
    # Accept plain strings as well as Path objects for the output location.
    output_dir = pathlib.Path(output_dir)

    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    # Copy so the caller's sequence is never mutated; validate when the model
    # tells us how many layers it has, so a wrong layer fails loudly instead of
    # silently producing files with no representations in them.
    repr_layers = list(repr_layers)
    num_layers = getattr(model, "num_layers", None)
    if num_layers is not None:
        invalid = [layer for layer in repr_layers if not (-(num_layers + 1) <= layer <= num_layers)]
        if invalid:
            raise ValueError(
                f"repr_layers {invalid} out of range for a model with {num_layers} layers"
            )
        repr_layers = [(layer + num_layers + 1) % (num_layers + 1) for layer in repr_layers]

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    dataset = FastaBatchedDataset.from_file(fasta_file)
    # extra_toks_per_seq=1 presumably reserves room for the special token the
    # batch converter prepends -- TODO confirm against the esm documentation.
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    filenames = []
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                # The entry id is the first whitespace-delimited word of the FASTA header.
                entry_id = label.split()[0]
                filename = output_dir / f"{entry_id}.pt"
                filenames.append(filename)
                truncate_len = min(seq_length, len(strs[i]))

                result = {
                    "entry_id": entry_id,
                    # Average over positions 1..truncate_len; index 0 is skipped
                    # (presumably the prepended start token) as is the padding.
                    "mean_representations": {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    },
                }
                torch.save(result, filename)
    return filenames

In [40]:

# Define paths
# The checkpoint location is derived from torch's hub cache directory instead of
# a user-specific absolute path, so the notebook runs on any machine/account.
model_path_without_extension = 'esm2_t36_3B_UR50D'
model_path_with_extension = os.path.join(
    torch.hub.get_dir(), 'checkpoints', model_path_without_extension + '.pt'
)

# Use the cached checkpoint file when it exists; otherwise fall back to the
# bare model name and let the loader resolve it.
if os.path.exists(model_path_with_extension):
    model_name = model_path_with_extension
else:
    model_name = model_path_without_extension

# Define the other paths (relative to the notebook's working directory)
fasta_file = pathlib.Path('Channel_Protein_Pred/input_sequence/input_sequence.fasta')
output_dir = pathlib.Path('Channel_Protein_Pred/embeddings/')

# Call the function; as the cell's last expression, the list of written files is displayed
extract_embeddings(model_name, fasta_file, output_dir)


Processing batch 1 of 2
Processing batch 2 of 2


[WindowsPath('Channel_Protein_Pred/embeddings/1.pt'),
 WindowsPath('Channel_Protein_Pred/embeddings/3.pt'),
 WindowsPath('Channel_Protein_Pred/embeddings/2.pt'),
 WindowsPath('Channel_Protein_Pred/embeddings/4.pt')]

In [39]:
import os
import re
import xgboost as xgb
import numpy as np
import torch
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, roc_curve
import matplotlib.pyplot as plt

def load_protein_representations(folder_path, files, layer=36):
    """Load saved mean representations and stack them into one feature matrix.

    Parameters
    ----------
    folder_path : str
        Directory containing the saved ``.pt`` files.
    files : iterable of str
        File names inside ``folder_path``; rows of the result follow this order.
    layer : int, optional
        Key looked up in each file's ``'mean_representations'`` dict (default 36).

    Returns
    -------
    torch.Tensor
        float32 tensor with one row per file. An empty ``files`` yields an empty
        1-D tensor.
    """
    vectors = []
    for file_name in files:
        # NOTE: torch.load unpickles the file -- only load embeddings you produced
        # yourself or otherwise trust. map_location keeps loading CPU-only safe.
        saved = torch.load(os.path.join(folder_path, file_name), map_location="cpu")
        vectors.append(saved['mean_representations'][layer].detach().to(torch.float32))
    if not vectors:
        return torch.Tensor([])
    # Stack directly instead of round-tripping every vector through Python lists.
    return torch.stack(vectors)

folder_path_test = 'Channel_Protein_Pred/embeddings/'


def _numeric_prefix_key(file_name):
    """Sort key: names with a leading integer first (numeric order), others after (alphabetical)."""
    match = re.match(r'(\d+)', file_name)
    if match:
        return (0, int(match.group(1)), file_name)
    return (1, 0, file_name)


# Only embedding files are considered, and names without a numeric prefix no
# longer crash the sort (re.match would return None for them).
files_test = sorted(
    (name for name in os.listdir(folder_path_test) if name.endswith('.pt')),
    key=_numeric_prefix_key,
)
if not files_test:
    raise FileNotFoundError(f"No .pt embedding files found in {folder_path_test!r}")

test_query_rep = load_protein_representations(folder_path_test, files_test)
test_data = test_query_rep.numpy()

# Load the model
loaded_model = xgb.XGBClassifier()
model_filename = 'Channel_Protein_Pred/model/xgboost_model.json'
loaded_model.load_model(model_filename)

# Make predictions
test_predictions = loaded_model.predict(test_data)

# One row per embedding file: running index, file name, predicted class label
for i, (header, prediction) in enumerate(zip(files_test, test_predictions), 1):
    channel_status = "Channel Protein" if prediction == 1 else "Non-Channel Protein"
    print(f"{i}\t{header}\t{channel_status:<20}")


1	1.pt	Non-Channel Protein 
2	2.pt	Non-Channel Protein 
3	3.pt	Channel Protein     
4	4.pt	Non-Channel Protein 
