In [34]:
#import libraries
import pathlib
import torch
import esm
from esm import pretrained
from esm import FastaBatchedDataset
from tqdm import tqdm

In [35]:
def extract_embeddings(output_dir, fasta_file, tokens_per_batch=4096, seq_length=7096, repr_layers=(36,)):
    """Embed every sequence of a FASTA file with ESM-2 (esm2_t36_3B_UR50D) and save mean representations.

    For each sequence one ``<entry_id>.pt`` file is written to ``output_dir`` holding
    ``{"entry_id": str, "mean_representations": {layer: tensor}}``. Each tensor is the
    representation of that layer averaged over the sequence's residue positions.

    Args:
        output_dir: Directory for the ``.pt`` files (``str`` or ``pathlib.Path``); created if missing.
        fasta_file: Path to the input FASTA file.
        tokens_per_batch: Token budget used to group sequences into batches.
        seq_length: Truncation length passed to the batch converter; residues beyond it are not embedded.
        repr_layers: Layer indices whose representations are saved. Negative values count from
            the end (``-1`` is the last layer); they are normalized to non-negative indices, which
            are the keys written to ``mean_representations``.

    Returns:
        List of ``pathlib.Path`` objects, one per saved file, in processing (batch) order —
        not necessarily FASTA order.

    Raises:
        ValueError: If a requested layer is outside the model's range.
    """
    output_dir = pathlib.Path(output_dir)  # accept plain strings as well as Path objects

    model, alphabet = pretrained.esm2_t36_3B_UR50D()
    model.eval()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    # Map negative indices onto 0..num_layers (e.g. -1 -> last layer) so saved keys are stable.
    num_layers = model.num_layers
    if not all(-(num_layers + 1) <= layer <= num_layers for layer in repr_layers):
        raise ValueError(
            f"repr_layers must be in [-{num_layers + 1}, {num_layers}], got {list(repr_layers)}"
        )
    layers = [(layer + num_layers + 1) % (num_layers + 1) for layer in repr_layers]

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    filenames = []
    with torch.no_grad():
        # tqdm already reports progress per batch, so no extra print is needed.
        for labels, strs, toks in tqdm(data_loader, total=len(batches), desc="Embedding batches"):
            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=layers, return_contacts=False)
            # Only the hidden representations are used; the logits are not copied back to the CPU.
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                # The first whitespace-separated token of the FASTA header names the output file.
                entry_id = label.split()[0]
                filename = output_dir / f"{entry_id}.pt"
                filenames.append(filename)
                truncate_len = min(seq_length, len(strs[i]))

                # Positions 1..truncate_len are the residues: position 0 holds the token the
                # alphabet prepends, and anything past the sequence length is end/padding.
                result = {
                    "entry_id": entry_id,
                    "mean_representations": {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    },
                }
                torch.save(result, filename)
    return filenames

In [36]:
# Paste the query sequence(s), in FASTA format, into Input_sequences.txt
fasta_file = pathlib.Path('Channel_Proteins_Pred2-main/Input_sequences/Input_sequences.txt')
output_dir = pathlib.Path('Channel_Proteins_Pred2-main/embeddings/')

# Embed every query sequence; returns the paths of the saved .pt files (one per sequence)
embedding_files = extract_embeddings(output_dir, fasta_file)
embedding_files

  0%|          | 0/1 [00:00<?, ?it/s]

Processing batch 1 of 1


100%|██████████| 1/1 [00:11<00:00, 11.90s/it]


[WindowsPath('Channel_Proteins_Pred2-main/embeddings/example2.pt'),
 WindowsPath('Channel_Proteins_Pred2-main/embeddings/example1.pt')]

In [37]:
import re, os
import pandas as pd
import numpy as np
import random
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve

# Load the query sequence representations
def load_protein_representations(folder_path, files, layer=36):
    """Load saved per-sequence mean embeddings into one 2-D float32 tensor.

    Args:
        folder_path: Directory containing the ``.pt`` files.
        files: File names (relative to ``folder_path``) to load, in the desired row order.
        layer: Key read from each file's ``mean_representations`` dict (default 36).

    Returns:
        Tensor of shape ``(n_loaded, embedding_dim)``, one row per existing file in the
        given order. Missing files are reported and skipped. If nothing is loaded, an
        empty 1-D tensor of shape ``(0,)`` is returned.
    """
    rows = []
    for file_name in files:
        file_path = os.path.join(folder_path, file_name)
        if not os.path.exists(file_path):
            print(f"File {file_path} not found.")
            continue
        # map_location='cpu' lets files saved from a GPU session load on CPU-only machines.
        # NOTE: torch.load unpickles its input; only load embedding files you produced yourself.
        rows.append(torch.load(file_path, map_location="cpu")['mean_representations'][layer])
    if not rows:
        return torch.empty(0)
    # Stack directly instead of round-tripping through Python lists; cast so the output
    # dtype is float32 regardless of how the embeddings were saved.
    return torch.stack(rows).float()

# Path to the per-sequence embedding files saved by the extraction step
folder_path = 'Channel_Proteins_Pred2-main/embeddings/'
# Only pick up saved embeddings: any other entry in the folder would make torch.load fail.
# Sorted for a deterministic row order. NOTE: .pt files left over from earlier runs are loaded too.
files_test = sorted(f for f in os.listdir(folder_path) if f.endswith('.pt'))

query_rep = load_protein_representations(folder_path, files_test)
# Fail early (rather than inside model.predict) if no embeddings were loaded.
assert query_rep.ndim == 2 and query_rep.shape[0] == len(files_test), (
    f"expected {len(files_test)} embeddings in {folder_path}, got shape {tuple(query_rep.shape)}"
)
query_rep = query_rep.numpy()

# neural network model
def create_model(input_dim=2560):
    """Build and compile the feed-forward binary classifier used for prediction.

    Architecture: Dense 128 -> 64 -> 32 -> 16 (ReLU) -> 1 (sigmoid), compiled with Adam
    (learning rate 0.001) and binary cross-entropy. The layer sizes must match the saved
    weights that are loaded into the model afterwards.

    Args:
        input_dim: Length of each input embedding vector. The default 2560 is expected to
            equal the width of the saved ESM-2 3B mean representations — confirm if the
            embedding model changes.

    Returns:
        A compiled ``tf.keras`` ``Sequential`` model.
    """
    model = Sequential([
        Dense(128, activation='relu', input_shape=(input_dim,)),
        Dense(64, activation='relu'),
        Dense(32, activation='relu'),
        Dense(16, activation='relu'),
        Dense(1, activation='sigmoid'),  # single probability output for the binary decision
    ])

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    # Compile the model with the optimizer
    model.compile(optimizer=optimizer,
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    return model

model = create_model()

# Load the trained weights
best_weights_path = 'Channel_Proteins_Pred2-main/model/model.h5'
model.load_weights(best_weights_path)

# Score each query embedding; labels are 1 when the sigmoid output exceeds 0.5
predictions = model.predict(query_rep)
predicted_labels = (predictions > 0.5).astype(int)

# Rows of query_rep follow files_test, so pair each prediction with the entry id it belongs to
# (this order differs from the extraction order, so unlabeled output would be ambiguous).
for file_name, label in zip(files_test, predicted_labels.ravel()):
    entry_id = os.path.splitext(file_name)[0]
    verdict = "channel protein" if label == 1 else "Non-channel protein"
    print(f"{entry_id}: {verdict}")

channel protein
Non-channel protein
