In [1]:
#import libraries
import pathlib
import torch
import esm
from esm import pretrained
from esm import FastaBatchedDataset
from tqdm import tqdm

In [2]:
def extract_embeddings(output_dir, fasta_file, tokens_per_batch=4096, seq_length=7096, repr_layers=[36]):
    model, alphabet = pretrained.esm2_t36_3B_UR50D()
    model.eval()

    if torch.cuda.is_available():
        model = model.cuda()
        
    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset, 
        collate_fn=alphabet.get_batch_converter(seq_length), 
        batch_sampler=batches
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    filenames = []  
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in tqdm(enumerate(data_loader), total=len(batches)):
            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)

            logits = out["logits"].to(device="cpu")
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}
            
            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                filename = output_dir / f"{entry_id}.pt"
                filenames.append(filename)  
                truncate_len = min(seq_length, len(strs[i]))

                result = {"entry_id": entry_id}
                result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }

                torch.save(result, filename)
    return filenames  

In [None]:
import pathlib

# Start from the current directory
current_path = pathlib.Path().resolve()


# Traverse upwards until we find 'LLPS_regulators_pred'
while current_path.name != "LLPS_regulators_pred":
    if current_path.parent == current_path:
        raise FileNotFoundError("Project root 'LLPS_regulators_pred' not found in path hierarchy.")
    current_path = current_path.parent

project_root = current_path
fasta_file = project_root / "Input_sequences" / "Input_sequences.txt"
output_dir = project_root / "embeddings"

# Run the function
extract_embeddings(output_dir, fasta_file)

In [4]:
import re, os
import pandas as pd
import numpy as np
import random
import glob
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, roc_curve
import torch
from tensorflow.keras.models import load_model


# Load the query sequence representations
def load_protein_representations(folder_path, files):
    queryproteinrep = []
    for file_name in files:
        file_path = os.path.join(folder_path, file_name)
        if os.path.exists(file_path):
            rep_changes = torch.load(file_path)['mean_representations'][36]
            queryproteinrep.append(rep_changes.tolist())
        else:
            print(f"File {file_path} not found.")
    return torch.tensor(queryproteinrep)



# Automatically detect project root
current_path = pathlib.Path().resolve()
while current_path.name != "LLPS_regulators_pred":
    if current_path.parent == current_path:
        raise FileNotFoundError("Project root 'LLPS_regulators_pred' not found in path hierarchy.")
    current_path = current_path.parent

project_root = current_path

# Path to sequence representations
folder_path = project_root / "embeddings"
files_test = sorted(os.listdir(folder_path))
query_rep = load_protein_representations(folder_path, files_test)
query_rep = query_rep.numpy()

def create_model():
    model = Sequential([
        Dense(128, activation='relu', input_shape=(2560,)),
        Dense(64, activation='relu'),
        Dense(32, activation='relu'),
        Dense(16, activation='relu'),
        Dense(1, activation='sigmoid')  
    ])

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

    model.compile(optimizer=optimizer,
                    loss='binary_crossentropy',
                    metrics=['accuracy'])
    return model

model = create_model()


# Load all saved models for ensemble predictions
import glob

# Path to models directory using project_root
models_path = project_root / "models"
model_files = sorted(models_path.glob("dataset_*.h5"))

# Convert Path objects to strings for compatibility if needed
#model_files = [str(file) for file in model_files]

# Initialize lists for ensemble predictions
loaded_predictions = []
loaded_probabilities = []


for model_file in model_files:
    loaded_model = load_model(model_file, compile = False)
    # Make predictions on the test set
    y_pred_proba = loaded_model.predict(query_rep, verbose=0)
    y_pred = (y_pred_proba > 0.5).astype(int)
    loaded_predictions.append(y_pred)
    loaded_probabilities.append(y_pred_proba)
    
# Convert predictions and probabilities to numpy arrays
ensemble_predictions = np.array(loaded_predictions)  
ensemble_probabilities = np.array(loaded_probabilities) 

# Using the model to make predictions
predictions = model.predict(query_rep)
predicted_labels = (predictions > 0.5).astype(int)

# After all folds are processed, calculate the final ensemble accuracy and other metrics
ensemble_predictions = np.array(ensemble_predictions)  
ensemble_probabilities = np.array(ensemble_probabilities) 

# Majority voting for final predictions
votes = np.sum(ensemble_predictions, axis=0)  
majority_decision = (votes > (ensemble_predictions.shape[0] // 2)).astype(int)

# Handle ties (if any)
ties = (votes == ensemble_predictions.shape[0] // 2)  
if np.any(ties):
    avg_probabilities = ensemble_probabilities.mean(axis=0)
    majority_decision[ties] = (avg_probabilities[ties] >= 0.5).astype(int)

# Final ensemble predictions
final_predictions = majority_decision
final_probabilities = ensemble_probabilities.mean(axis=0)

for final_prediction in final_predictions:
    if final_prediction == 1:
        print("Regulator in LLPS")
    if final_prediction == 0:
        print("Non-Regulator in LLPS")

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 47ms/step
Regulator in LLPS
Non-Regulator in LLPS
