In [None]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from torch import nn 
import torch.nn.functional as F

import os
import numpy as np
import pandas as pd 
from tqdm import tqdm
from Bio import SeqIO
from collections import Counter, defaultdict
from multiprocessing.pool import ThreadPool
from concurrent.futures import ProcessPoolExecutor
import warnings
warnings.filterwarnings("ignore") 


# Input FASTA file to screen, and the working directory that receives the
# prediction table written further down.
path_fasta = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/15122022_session/part_III_ptA/input_db/all_prophage_proteins.db.fasta"
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"


# *********************************************************************
# Define and load DepoScope :
# Checkpoint directory of the fine-tuned ESM-2 token classifier, and the file
# holding the saved weights that are loaded into the Dpo_classifier below.
esm2_model_path = "/home/conchae/PhageDepo_pdb/script_files/esm2_t12_35M_UR50D-finetuned-depolymerase.labels_4/checkpoint-6015"
DpoDetection_path = "/home/conchae/PhageDepo_pdb/DepoDetection.T12.4Labels.1908.model"

# The tokenizer and the token-classification model are restored from the same
# checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(esm2_model_path)
esm2_finetuned = AutoModelForTokenClassification.from_pretrained(esm2_model_path)

class Dpo_classifier(nn.Module):
    """Binary classifier stacked on the per-token labels of a token-classification model.

    Every input sequence is labelled token by token with ``pretrained_model``.
    The resulting label track is padded or cut to ``max_length`` positions and
    fed through two 1D convolutions followed by two linear layers, which yield
    one logit per sequence.
    """

    def __init__(self, pretrained_model):
        """Store the token classifier and build the convolutional head."""
        super().__init__()
        self.max_length = 1024
        self.pretrained_model = pretrained_model
        kernel = 5
        self.conv1 = nn.Conv1d(1, 64, kernel_size=kernel, stride=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=kernel, stride=1)
        # Each unpadded convolution shortens the track by (kernel - 1)
        # positions, hence the flattened size expected by the first linear layer.
        conv_out_length = self.max_length - 2 * (kernel - 1)
        self.fc1 = nn.Linear(128 * conv_out_length, 32)
        self.classifier = nn.Linear(32, 1)  # one logit -> binary decision

    def make_prediction(self, fasta_txt):
        """Return the most probable label of every token as a 2D ``(1, n)`` tensor.

        The sequence is encoded with the module-level ``tokenizer``.
        """
        encoded = tokenizer.encode(fasta_txt, truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = self.pretrained_model(encoded).logits
            label_probs = F.softmax(logits, dim=-1)
            best_labels = torch.max(label_probs, dim=-1).indices
        return best_labels.view(1, -1)

    def pad_or_truncate(self, tokens):
        """Bring the second dimension of ``tokens`` to exactly ``max_length``.

        Shorter tracks are right-padded with zeros, longer ones are cut on the
        right, and tracks that already have the target length are returned as is.
        """
        missing = self.max_length - tokens.size(1)
        if missing > 0:
            return F.pad(tokens, (0, missing))
        if missing < 0:
            return tokens[:, :self.max_length]
        return tokens

    def forward(self, sequences):
        """Score a batch of sequences.

        Returns a tuple ``(logits, label_tracks)`` where ``logits`` is the
        output of the final linear layer and ``label_tracks`` is the float
        tensor of shape ``(len(sequences), 1, max_length)`` given to the
        convolutions.
        """
        n_sequences = len(sequences)
        label_tracks = [
            self.pad_or_truncate(self.make_prediction(seq)) for seq in sequences
        ]
        # Conv1d expects (batch, channels, length) with a floating dtype.
        outputs = torch.cat(label_tracks).view(n_sequences, 1, self.max_length).float()

        hidden = F.relu(self.conv1(outputs))
        hidden = F.relu(self.conv2(hidden))
        hidden = F.relu(self.fc1(hidden.view(n_sequences, -1)))
        return self.classifier(hidden), outputs

model_classifier = Dpo_classifier(esm2_finetuned)  # Create an instance of Dpo_classifier
# map_location="cpu" lets a checkpoint that was saved from a GPU be restored on
# a machine without one; load_state_dict copies the values into the existing
# parameters afterwards.
_saved_weights = torch.load(DpoDetection_path, map_location="cpu")
# strict=False tolerates keys that do not line up between the checkpoint and the
# module. The mismatches are reported instead of being dropped silently, because
# a layer whose weights were not restored keeps its random initialisation.
# print() is used since warnings are globally silenced at the top of the script.
_load_report = model_classifier.load_state_dict(_saved_weights, strict=False)
if _load_report.missing_keys or _load_report.unexpected_keys:
    print(
        f"[Dpo_classifier] weights loaded with strict=False | "
        f"missing keys ({len(_load_report.missing_keys)}): {_load_report.missing_keys} | "
        f"unexpected keys ({len(_load_report.unexpected_keys)}): {_load_report.unexpected_keys}"
    )
model_classifier.eval()  # Set the model to evaluation mode for inference


# *********************************************************************
# Useful functions :

def predict_sequence(model, sequence):
    """Run ``model`` on a single sequence and summarise its output.

    The model is called with a one-element list and must return a pair
    ``(logits, label_tracks)``.

    Returns
    -------
    tuple
        ``((label, probability), track)`` where ``probability`` is the sigmoid
        of the logit, ``label`` is ``1.0`` when that probability exceeds 0.5
        and ``0.0`` otherwise, and ``track`` is the ``[0][0]`` entry of the
        second model output converted to a Python list.
    """
    model.eval()
    with torch.no_grad():
        # The model consumes batches, so the lone sequence is wrapped in a list.
        logits, label_tracks = model([sequence])
        probability = torch.sigmoid(logits)
        label = (probability > 0.5).float()
        track = label_tracks[0, 0].cpu().numpy().tolist()
        return (label.item(), probability[0].item()), track


def find_longest_non_zero_suite_with_n_zeros(lst, n):
    """Locate the longest contiguous stretch of ``lst`` holding at most ``n`` zeros.

    A window is slid over the list; whenever it contains more than ``n`` zeros
    its left edge is advanced until the surplus zero has left the window.

    Returns
    -------
    tuple
        ``(start, end)`` indices of the best window, both inclusive. The first
        window reaching the maximal length wins. ``(0, 0)`` is returned when no
        window of positive length exists (e.g. an empty list).
    """
    best_start, best_end, best_length = 0, 0, 0
    window_start = 0
    zeros_in_window = 0
    for position, value in enumerate(lst):
        if value == 0:
            zeros_in_window += 1
            # Too many zeros: drop elements from the left until one zero is gone.
            while zeros_in_window > n:
                if lst[window_start] == 0:
                    zeros_in_window -= 1
                window_start += 1
        # The window spans window_start..position inclusive.
        window_length = position - window_start + 1
        if window_length > best_length:
            best_start, best_end, best_length = window_start, position, window_length
    return best_start, best_end
    

# *********************************************************************
# Load the sequences into a dictionary :
# Group protein identifiers by sequence so that identical sequences are scored
# only once; sequences shorter than 200 characters are left out.
fasta_seqs = SeqIO.parse(path_fasta, "fasta")
dico_seq = defaultdict(list)
for record in fasta_seqs:
    tmp_prot_name, sequence = record.id, str(record.seq)
    if len(sequence) < 200:
        continue
    dico_seq[sequence].append(tmp_prot_name)


# *********************************************************************
# Make the predictions : 
# Destination of the prediction table (tab-separated, no header).
PREDICTIONS_TSV = f"{path_work}/Anubis_return.predictions.0709.tsv"


def run_predictions(item):
    """Score one unique sequence and record it when it is predicted positive.

    Parameters
    ----------
    item : tuple
        ``(sequence, prot_names)`` pair, as yielded by ``dico_seq.items()``.

    For a positive prediction, one row per protein name is appended to
    ``PREDICTIONS_TSV`` with the columns: protein name, window start, window
    end, ``sequence[start:end]`` and the full sequence. Nothing is written for
    a negative prediction. The function returns ``None``.
    """
    sequence, prot_names = item[0], item[1]
    prediction, sequence_outputs = predict_sequence(model_classifier, sequence)
    if prediction[0] != 1:
        return
    start, end = find_longest_non_zero_suite_with_n_zeros(sequence_outputs, 10)
    # NOTE(review): `end` is the inclusive index returned by the window search,
    # whereas the slice below is end-exclusive, and the label indices come from
    # the tokenizer output, which may be shifted by special tokens. Confirm the
    # intended residue coordinates before relying on this column.
    domain = sequence[int(start):int(end)]
    # Append rather than overwrite: this function runs once per sequence, and
    # opening with "w" here would leave only the last positive hit in the file.
    with open(PREDICTIONS_TSV, "a") as outfile:
        outfile.writelines(
            f"{prot}\t{start}\t{end}\t{domain}\t{sequence}\n" for prot in prot_names
        )


if __name__ == '__main__':
    # Start from an empty table so rows of a previous run are not mixed with
    # the rows appended by this one.
    with open(PREDICTIONS_TSV, "w"):
        pass
    RUN = [run_predictions(item) for item in dico_seq.items()]