In [2]:
#import libraries
import pathlib
import torch
import esm
from esm import FastaBatchedDataset, pretrained

In [3]:
# Download the ESM-2 model (the first download takes a while; after that the code runs in a few seconds) and extract the embeddings
def extract_embeddings(fasta_file, output_dir, tokens_per_batch=4096, seq_length=7096, repr_layers=(36,)):
    """Embed every sequence of a FASTA file with ESM-2 and save one ``.pt`` file per sequence.

    Each saved file holds a dict with two keys: ``"entry_id"`` (the first
    whitespace-separated word of the FASTA header) and
    ``"mean_representations"`` (``{layer: tensor}``, the per-residue
    representations of that layer averaged over the sequence positions).

    Args:
        fasta_file: Path (``str`` or ``pathlib.Path``) of the FASTA file to embed.
        output_dir: Directory that receives the ``<entry_id>.pt`` files; it is
            created (including parents) when it does not exist.
        tokens_per_batch: Token budget handed to
            ``FastaBatchedDataset.get_batch_indices`` when grouping sequences.
        seq_length: Truncation length handed to the batch converter; also the
            maximum number of residues that are averaged per sequence.
        repr_layers: Iterable of model layer indices whose representations are
            requested and saved.

    Returns:
        list[pathlib.Path]: The written files, in processing order.

    Raises:
        FileNotFoundError: If ``fasta_file`` does not point to an existing file.
    """
    fasta_file = pathlib.Path(fasta_file)
    output_dir = pathlib.Path(output_dir)

    # Fail fast: validate the input before the slow model download / load.
    if not fasta_file.is_file():
        raise FileNotFoundError(f"FASTA file not found: {fasta_file}")

    model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
    model.eval()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches,
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    filenames = []
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=list(repr_layers), return_contacts=False)

            # Move the requested layers back to the CPU so they can be sliced and saved.
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                filename = output_dir / f"{entry_id}.pt"
                filenames.append(filename)
                truncate_len = min(seq_length, len(strs[i]))

                result = {"entry_id": entry_id}
                # Positions 1..truncate_len: position 0 is skipped (the token the
                # batch converter prepends) and so is any trailing padding.
                result["mean_representations"] = {
                    layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                    for layer, t in representations.items()
                }

                torch.save(result, filename)
    return filenames

In [None]:
# Paste the query sequence(s), in FASTA format, into query_sequence.txt before running this cell.
fasta_file = pathlib.Path('LLPS_Pred/query_sequences/query_sequence.txt')
# Directory intended to receive the saved embedding files.
output_dir = pathlib.Path('LLPS_Pred/embeddings/')
# Run the embedding extraction for the query file.
extract_embeddings(fasta_file, output_dir)

In [4]:
import os
import re
import torch
import os
import re
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class SimpleCNN(nn.Module):
    """Two-layer 1-D CNN followed by two fully connected layers.

    The network consumes vectors of length ``INPUT_DIM`` (treated as a
    single-channel 1-D signal) and returns two sigmoid-activated scores per
    sample.
    """

    # Length of one input vector. Kept at class level so that model objects
    # restored from a pickle (whose ``__init__`` never runs) can see it too.
    INPUT_DIM = 2560

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3)
        self.batchnorm1 = nn.BatchNorm1d(512)
        self.pool = nn.MaxPool1d(kernel_size=4)
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3)
        self.batchnorm2 = nn.BatchNorm1d(256)

        self.fc_input_size = self.calculate_fc_input_size()
        self.fc1 = nn.Linear(self.fc_input_size, 128)
        self.batchnorm3 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor of sigmoid scores.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``INPUT_DIM`` values, e.g. ``(batch, INPUT_DIM)`` or
                ``(batch, 1, INPUT_DIM)``.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(-1, 1, self.INPUT_DIM)

        x = self.pool(F.relu(self.batchnorm1(self.conv1(x))))
        x = self.pool(F.relu(self.batchnorm2(self.conv2(x))))

        # Flatten channels x length per sample; no module state is touched here.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.batchnorm3(self.fc1(x)))
        x = torch.sigmoid(self.fc2(x))

        return x

    def calculate_fc_input_size(self):
        """Return the flattened feature count that reaches ``fc1``.

        Only the convolution and pooling layers change the tensor shape, so the
        probe skips batch-norm and ReLU. That keeps the batch-norm running
        statistics untouched, and using zeros avoids drawing random numbers.
        """
        weight = self.conv1.weight
        with torch.no_grad():
            probe = torch.zeros(1, 1, self.INPUT_DIM, dtype=weight.dtype, device=weight.device)
            probe = self.pool(self.conv1(probe))
            probe = self.pool(self.conv2(probe))
        return probe[0].numel()

# Load embeddings
folder_path = 'LLPS_Pred/embeddings/'
# Keep only the saved ``.pt`` embedding files: other directory entries
# (checkpoint folders, hidden OS files, ...) would break torch.load and the
# file-name <-> prediction pairing further down.
files = sorted(name for name in os.listdir(folder_path) if name.endswith('.pt'))

def load_protein_representations():
    """Load the layer-36 mean representation of every file listed in ``files``.

    Returns:
        torch.Tensor: float32 tensor with one row per file, in the order of
        ``files``, with a singleton dimension inserted at position 1. When
        ``files`` is empty an empty ``(0, 1)`` tensor is returned.
    """
    queryproteinrep = [
        torch.load(os.path.join(folder_path, file_name), map_location='cpu')['mean_representations'][36]
        for file_name in files
    ]
    if not queryproteinrep:
        return torch.empty(0, 1)
    # Stack the tensors directly instead of round-tripping through Python lists.
    return torch.stack(queryproteinrep).float().unsqueeze(1)

query_rep = load_protein_representations()

# Thin Dataset adapter so the embeddings can be fed to a DataLoader
class ProteinDataset(Dataset):
    """Expose an indexable collection of samples through the ``Dataset`` protocol."""

    def __init__(self, data):
        # Stored as given (no copy, no conversion); it only has to support
        # len() and integer indexing.
        self.data = data

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

    def __len__(self):
        """Return the number of stored samples."""
        n_samples = len(self.data)
        return n_samples

dataset = ProteinDataset(query_rep)

batch_size = 512
# shuffle=False keeps the predictions in the same order as ``files``.
dataloader_val = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Load model
load_model_path = 'LLPS_Pred/model/model.pth'
# NOTE(review): this un-pickles a complete model object, which can execute
# arbitrary code -- only load model files from a trusted source. Recent PyTorch
# releases default torch.load to weights_only=True, which rejects such files;
# confirm against the installed version.
# map_location='cpu' lets a model saved on a GPU load on CPU-only machines and
# keeps it on the same device as the CPU input batches below.
loaded_model = torch.load(load_model_path, map_location='cpu')
loaded_model.eval()

predicted_labels = []
with torch.no_grad():
    # Iterate over batches only; file names are paired with predictions afterwards.
    for data_val in dataloader_val:
        output_val = loaded_model(data_val)
        _, predicted = torch.max(output_val, 1)
        predicted_labels.extend(predicted.cpu().numpy())

# Predicted results
for file_name, label in zip(files, predicted_labels):
    print(f"File: {file_name}, Prediction: {'LLPS' if label == 0 else 'non-LLPS'}")


File: example1.pt, Prediction: LLPS
File: example2.pt, Prediction: non-LLPS
File: example3.pt, Prediction: non-LLPS
