# The final DpoDetection tool
***

In [1]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from torch import nn 
import torch.nn.functional as F

import os
import numpy as np
import tqdm
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning) 

# Root directory holding the fine-tuned ESM-2 checkpoint and the classifier weights.
path_work = "/media/concha-eloko/Linux/depolymerase_building"

# Weights file of the DpoDetection classifier.
DpoDetection_path = f"{path_work}/DepoDetection.T12.4Labels.1908.model"
# Checkpoint directory of the fine-tuned ESM-2 token classifier.
esm2_model_path = f"{path_work}/esm2_t12_35M_UR50D-finetuned-depolymerase.labels_4/checkpoint-6015"

# The token-classification model and its tokenizer are read from the same checkpoint.
esm2_finetuned = AutoModelForTokenClassification.from_pretrained(esm2_model_path)
tokenizer = AutoTokenizer.from_pretrained(esm2_model_path)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Dpo_classifier(nn.Module):
    """Binary classifier stacked on top of a token-classification model.

    Every input sequence is converted into a row of per-token label ids by the
    wrapped ``pretrained_model``. The row is padded or cropped to ``max_length``
    and passed through two 1D convolutions and two linear layers, which yield a
    single logit per sequence.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        kernel = 5
        self.max_length = 1024
        self.pretrained_model = pretrained_model
        self.conv1 = nn.Conv1d(1, 64, kernel_size=kernel, stride=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=kernel, stride=1)
        # Each un-padded convolution shortens the signal by (kernel - 1) positions.
        flat_features = 128 * (self.max_length - 2 * (kernel - 1))
        self.fc1 = nn.Linear(flat_features, 32)
        self.classifier = nn.Linear(32, 1)  # one logit -> binary decision

    def make_prediction(self, fasta_txt):
        """Return the most probable label id of each token as a (1, n_tokens) tensor.

        The sequence is encoded with the module-level ``tokenizer``.
        """
        encoded = tokenizer.encode(fasta_txt, truncation=True, return_tensors='pt')
        with torch.no_grad():
            logits = self.pretrained_model(encoded).logits
            _, best_labels = F.softmax(logits, dim=-1).max(dim=-1)
        return best_labels.view(1, -1)  # force a 2D (1, n_tokens) layout

    def pad_or_truncate(self, tokens):
        """Bring ``tokens`` (2D) to exactly ``max_length`` columns.

        Shorter rows are right-padded with zeros, longer rows are cropped, and
        rows that already have the right width are returned untouched.
        """
        width = tokens.size(1)
        if width > self.max_length:
            return tokens[:, :self.max_length]
        if width < self.max_length:
            return F.pad(tokens, (0, self.max_length - width))
        return tokens

    def forward(self, sequences):
        """Score a batch of sequences.

        Returns a pair ``(logits, label_tracks)``: the classifier output with one
        row per sequence, and the float tensor of per-token labels of shape
        ``(batch, 1, max_length)`` that was fed to the convolutions.
        """
        n_seqs = len(sequences)
        rows = [self.pad_or_truncate(self.make_prediction(seq)) for seq in sequences]

        # Stack the rows into (batch, channel=1, max_length) and switch to float for Conv1d.
        outputs = torch.cat(rows).view(n_seqs, 1, self.max_length).float()

        hidden = F.relu(self.conv1(outputs))
        hidden = F.relu(self.conv2(hidden))
        hidden = F.relu(self.fc1(hidden.view(n_seqs, -1)))  # flatten before the dense layer
        return self.classifier(hidden), outputs

In [3]:
model_classifier = Dpo_classifier(esm2_finetuned) # Create an instance of Dpo_classifier

# map_location="cpu" lets the checkpoint deserialise even if it was saved from a GPU
# and no CUDA device is available; nothing in this notebook moves the model to a GPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(DpoDetection_path, map_location="cpu")

# strict=False tolerates key mismatches, but any weight that is skipped leaves its layer
# at its freshly initialised values. Report the mismatches instead of hiding them.
load_report = model_classifier.load_state_dict(state_dict, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(f"load_state_dict: {len(load_report.missing_keys)} missing key(s): {load_report.missing_keys}")
    print(f"load_state_dict: {len(load_report.unexpected_keys)} unexpected key(s): {load_report.unexpected_keys}")

model_classifier.eval() # Set the model to evaluation mode for inference


Dpo_classifier(
  (pretrained_model): EsmForTokenClassification(
    (esm): EsmModel(
      (embeddings): EsmEmbeddings(
        (word_embeddings): Embedding(33, 480, padding_idx=1)
        (dropout): Dropout(p=0.0, inplace=False)
        (position_embeddings): Embedding(1026, 480, padding_idx=1)
      )
      (encoder): EsmEncoder(
        (layer): ModuleList(
          (0-11): 12 x EsmLayer(
            (attention): EsmAttention(
              (self): EsmSelfAttention(
                (query): Linear(in_features=480, out_features=480, bias=True)
                (key): Linear(in_features=480, out_features=480, bias=True)
                (value): Linear(in_features=480, out_features=480, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
                (rotary_embeddings): RotaryEmbedding()
              )
              (output): EsmSelfOutput(
                (dense): Linear(in_features=480, out_features=480, bias=True)
                (dropout): Dropout(p=0.0, inpla

In [4]:
def predict_sequence(model, sequence):
    """Classify a single sequence with ``model``.

    Returns ``((label, probability), token_labels)``. ``label`` is 1.0 when the
    sigmoid of the model's first output exceeds 0.5 and 0.0 otherwise,
    ``probability`` is that sigmoid value, and ``token_labels`` is the ``[0][0]``
    slice of the model's second output converted to a plain Python list.
    """
    model.eval()
    with torch.no_grad():
        # The model works on batches, so the lone sequence is wrapped in a list.
        logits, token_tensor = model([sequence])
        probability = torch.sigmoid(logits)
        label = (probability > 0.5).float().item()
        token_labels = token_tensor.cpu().numpy()[0][0].tolist()
    return (label, probability[0].item()), token_labels


def plot_token(tokens) :
    """Plot the per-token label track of one sequence.

    Each pair of consecutive tokens is drawn as a segment coloured by the label
    of its left-hand token: 0 -> black, 1 -> blue, 2 -> red, anything else -> green.

    Parameters
    ----------
    tokens : sequence of numbers
        Label of every token, in sequence order.
    """
    tokens = np.asarray(tokens)
    label_colors = {0: 'black', 1: 'blue', 2: 'red'}  # every other label is drawn in green

    plt.figure(figsize=(10,6))
    for i in range(len(tokens) - 1):
        color = label_colors.get(float(tokens[i]), 'green')
        plt.plot([i, i+1], [tokens[i], tokens[i+1]], color=color, marker='o')
    plt.xlabel('Token')
    plt.ylabel('Label')
    plt.title('Label for each token')
    plt.xticks(rotation='vertical')
    # Tick every label, not just 0 and 1: four label classes are coloured above, so the
    # axis shows at least 0..3 and grows if a larger label is present in the data.
    n_labels = max(4, int(tokens.max()) + 1) if tokens.size else 4
    plt.yticks(np.arange(n_labels), [str(k) for k in range(n_labels)])
    plt.grid(True)
    plt.show()

***
# Predictions Ferriol

> Make predictions 

In [19]:
from Bio import SeqIO
from tqdm import tqdm 
from collections import Counter

path_out = "/media/concha-eloko/Linux/77_strains_phage_project"

prediction_results = {}
path_fasta = f"{path_out}/all_dpos.77_phages.multi.fasta"
fastas = SeqIO.parse(path_fasta, "fasta")
tmp_results = []
for record in tqdm(fastas):
    # Records shorter than 200 residues are not scored.
    if len(record.seq) < 200:
        continue
    protein_seq = record.seq
    prediction, sequence_outputs = predict_sequence(model_classifier, str(protein_seq))
    # Identifier: the whole description, or (when it contains a comma) its first
    # comma-separated field with spaces turned into underscores.
    description = record.description
    prot_id = description if "," not in description else description.split(",")[0].replace(" ", "_")
    # Keep positive calls only, together with the count of each per-token label.
    if prediction[0] == 1:
        tmp_results.append((prot_id, dict(Counter(sequence_outputs))))

132it [02:03,  1.07it/s]


In [20]:
# Positive predictions as (protein id, {token label: count}) pairs.
tmp_results

[('K10PH82C1_cds_50', {0.0: 501, 1.0: 523}),
 ('K10PH82C1_cds_51', {0.0: 685, 1.0: 339}),
 ('K11PH164C1_cds_45', {0.0: 694, 1.0: 330}),
 ('K11PH164C1_cds_46', {0.0: 595, 1.0: 429}),
 ('K13PH07C1L_cds_10', {0.0: 626, 1.0: 398}),
 ('K13PH07C1L_cds_11', {0.0: 917, 1.0: 107}),
 ('K13PH07C1L_cds_12', {0.0: 810, 1.0: 214}),
 ('K13PH07C1S_cds_10', {0.0: 626, 1.0: 398}),
 ('K13PH07C1S_cds_11', {0.0: 648, 1.0: 376}),
 ('K14PH164C1_cds_24', {0.0: 534, 1.0: 490}),
 ('K15PH90_cds_55', {1.0: 543, 0.0: 481}),
 ('K16PH164C3_cds_48', {0.0: 534, 1.0: 490}),
 ('K17alfa61_cds_23', {0.0: 847, 1.0: 177}),
 ('K17alfa62_cds_64', {0.0: 603, 1.0: 421}),
 ('K17alfa62_cds_66', {0.0: 614, 1.0: 410}),
 ('K18PH07C1_cds_243', {0.0: 701, 1.0: 323}),
 ('K18PH07C1_cds_245', {0.0: 716, 1.0: 308}),
 ('K1PH164C1_cds_8', {0.0: 630, 1.0: 394}),
 ('K21lambda1_cds_28', {0.0: 815, 2.0: 209}),
 ('K22PH164C1_cds_10', {0.0: 661, 1.0: 363}),
 ('K22PH164C1_cds_11', {0.0: 688, 1.0: 336}),
 ('K23PH08C2_cds_233', {0.0: 677, 1.0: 347})

In [21]:
# Token label -> fold name. Labels missing from this mapping (such as 0.0) carry no fold.
folds_label = {1.0 : "right-handed beta-helix", 2.0 : "6-bladed beta-propeller", 3.0 : "triple-helix"}
fold_dpoes = {}

for protein_name, label_counts in tmp_results:
    # The first fold label met while iterating the count dict wins; the counts themselves
    # are not compared. Proteins without any fold label get no entry.
    first_fold = next((folds_label[label] for label in label_counts if label in folds_label), None)
    if first_fold is not None:
        fold_dpoes[protein_name] = first_fold
fold_dpoes

{'K10PH82C1_cds_50': 'right-handed beta-helix',
 'K10PH82C1_cds_51': 'right-handed beta-helix',
 'K11PH164C1_cds_45': 'right-handed beta-helix',
 'K11PH164C1_cds_46': 'right-handed beta-helix',
 'K13PH07C1L_cds_10': 'right-handed beta-helix',
 'K13PH07C1L_cds_11': 'right-handed beta-helix',
 'K13PH07C1L_cds_12': 'right-handed beta-helix',
 'K13PH07C1S_cds_10': 'right-handed beta-helix',
 'K13PH07C1S_cds_11': 'right-handed beta-helix',
 'K14PH164C1_cds_24': 'right-handed beta-helix',
 'K15PH90_cds_55': 'right-handed beta-helix',
 'K16PH164C3_cds_48': 'right-handed beta-helix',
 'K17alfa61_cds_23': 'right-handed beta-helix',
 'K17alfa62_cds_64': 'right-handed beta-helix',
 'K17alfa62_cds_66': 'right-handed beta-helix',
 'K18PH07C1_cds_243': 'right-handed beta-helix',
 'K18PH07C1_cds_245': 'right-handed beta-helix',
 'K1PH164C1_cds_8': 'right-handed beta-helix',
 'K21lambda1_cds_28': '6-bladed beta-propeller',
 'K22PH164C1_cds_10': 'right-handed beta-helix',
 'K22PH164C1_cds_11': 'right-h

In [22]:
# Dump the protein -> fold table as a two-column, tab-separated file with a header row.
with open("/media/concha-eloko/Linux/PPT_clean/in_vitro/Celia/dpos_folds.celia.tsv", "w") as outfile :
    outfile.write("protein_id\tFold\n")
    outfile.writelines(f"{protein}\t{fold}\n" for protein, fold in fold_dpoes.items())

> Save / Open predictions

In [None]:
import os
from tqdm import tqdm
from Bio import SeqIO

path_bea = "/media/concha-eloko/Linux/PPT_clean/in_vitro/Bea"
# Built from path_bea so the directory is spelled out only once.
predictions_file = f"{path_bea}/DepoScope_predictions.tsv"

# How the predictions file was written (kept for reference):
#dpos = set([prot_id[1] for file in prediction_results for prot_id in prediction_results[file]])

#with open("/media/concha-eloko/Linux/PPT_clean/in_vitro/Bea/DepoScope_predictions.tsv", "w") as outfile : 
#    for dpo in dpos :
#        outfile.write(dpo + "\n")

# Read the saved ids back, one per line. The context manager closes the file handle
# (the previous open(...).read() never did). Empty entries are dropped: the writer above
# ends every id with "\n", so a plain split("\n") yields a spurious trailing "".
with open(predictions_file) as infile:
    dpos = [line for line in infile.read().split("\n") if line]