# load model

In [None]:
import ast
import math
import pickle
from collections import Counter, deque
import warnings
import numpy as np
import pandas as pd
import torch
from sklearn.exceptions import UndefinedMetricWarning
from sklearn.metrics import (accuracy_score, average_precision_score, f1_score,
                             precision_score, recall_score, roc_auc_score)
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
from sklearn.preprocessing import MultiLabelBinarizer
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import networkx as nx
import obonet
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import StandardScaler
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold, MultilabelStratifiedShuffleSplit
import random
import json

import sys

from utils_corrected import process_GO_data

In [None]:

class FFNN(nn.Module):
    """Plain feed-forward network producing raw (un-activated) outputs.

    The network is a stack of ``Linear -> LeakyReLU`` pairs, one per entry in
    ``hidden_dims``, followed by an optional ``Dropout`` and a final ``Linear``
    layer. No activation is applied to the final layer, so callers apply their
    own (e.g. ``torch.sigmoid``) to the result of ``forward``.

    Args:
        input_dim: Number of features in each input vector.
        hidden_dims: Sizes of the hidden layers, in order. Any sequence
            (list, tuple, ...) is accepted; it may be empty.
        output_dim: Number of values produced per input vector.
        dropout_rate: Probability passed to ``nn.Dropout`` before the output
            layer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout_rate=0.4):
        super().__init__()
        # Unpacking (rather than list concatenation) lets callers pass a tuple
        # or any other iterable of sizes without hitting a TypeError.
        dims = [input_dim, *hidden_dims]

        layers = []
        for in_features, out_features in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU())

        # Dropout sits directly before the output layer and is skipped when the
        # last width is exactly 2.
        # NOTE(review): presumably this protects a 2-D bottleneck - confirm.
        if dims[-1] != 2:
            layers.append(nn.Dropout(dropout_rate))

        # Output layer. The layer order above must not change: the positions
        # inside nn.Sequential determine the state_dict key names.
        layers.append(nn.Linear(dims[-1], output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the raw outputs."""
        return self.net(x)


In [None]:
# Network dimensions. They have to agree with the checkpoint whose weights are
# loaded into the model, otherwise load_state_dict will reject the shapes.
input_dim = 512      # width of each embedding vector fed to the model
hidden_dims = [256]  # a single hidden layer of 256 units
output_dim = 1098    # one output per label column

In [4]:
# Rebuild the architecture and restore the trained weights from disk.
model = FFNN(input_dim, hidden_dims, output_dim)

# map_location="cpu" makes the load succeed even when the checkpoint was saved
# from a GPU and this machine has none; the model is moved to the inference
# device later.
state_dict = torch.load("models/model_lr0.001_arch[256]_epochs20.pt", map_location="cpu")

# Kept as the last expression so the key-matching report is displayed.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
from Bio import SeqIO

# Read every record of the original protein FASTA, keep only sequences of at
# most 1200 residues, and write the survivors to a sibling
# ".length_filtered.fasta" file.
input_fasta = "/shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.original.long.prots.fasta"
records = list(SeqIO.parse(input_fasta, "fasta"))
filtered = list(filter(lambda record: len(record.seq) <= 1200, records))

# Last expression of the cell: its return value is what the notebook displays.
SeqIO.write(
    filtered,
    "/shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.original.long.prots.length_filtered.fasta",
    "fasta",
)

4282

In [11]:
# Run the external embedding script on the length-filtered FASTA, writing its
# results under ./embeddings/ and using the TM-Vec checkpoint/config given
# below on the GPU. Per the script's own log, it produces both an
# "*_embeddings.npz" and an "*_embeddings.npy" file.
%run /shared/projects/deepmar/PlasmoFP_public/src/generate_embeddings.py --input /shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.original.long.prots.length_filtered.fasta --output ./embeddings/ --output_format npz --tm_vec_model /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt  --device cuda  --tm_vec_config /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model_params.json

Embedding generation started!
Output directory: embeddings
Log file: embeddings/embedding_generation.log
Verbose mode: OFF
Using device: cuda
2026-01-19 01:46:57,063 - INFO - Using device: cuda
Found 1 FASTA file(s) to process:
   1. cyanobact.original.long.prots.length_filtered.fasta

2026-01-19 01:46:57,085 - INFO - Found 1 FASTA file(s) to process
Loading ProtT5 tokenizer and model...
2026-01-19 01:46:57,085 - INFO - Loading ProtT5 tokenizer and model...


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading TM-Vec model...
2026-01-19 01:47:38,510 - INFO - Loading TM-Vec model...


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../../../../projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt`


Models loaded successfully on device: cuda
2026-01-19 01:47:39,694 - INFO - Models loaded successfully on device: cuda
Starting processing...
[1/1] Processing: cyanobact.original.long.prots.length_filtered.fasta
Processing file: cyanobact.original.long.prots.length_filtered.fasta
2026-01-19 01:47:39,695 - INFO - Processing file: cyanobact.original.long.prots.length_filtered.fasta
Found 4282 sequences in cyanobact.original.long.prots.length_filtered.fasta
2026-01-19 01:47:39,711 - INFO - Found 4282 sequences in cyanobact.original.long.prots.length_filtered.fasta
Generating embeddings... (this may take a while)
2026-01-19 01:47:39,712 - INFO - Generating embeddings...


100%|██████████| 4282/4282 [14:54<00:00,  4.78it/s]


Successfully processed cyanobact.original.long.prots.length_filtered.fasta: 4282 sequences
Output files: cyanobact.original.long.prots.length_filtered_embeddings.npz, cyanobact.original.long.prots.length_filtered_embeddings.npy
Embedding shape: (4282, 512)
2026-01-19 02:02:35,102 - INFO - Successfully processed cyanobact.original.long.prots.length_filtered.fasta: 4282 sequences → cyanobact.original.long.prots.length_filtered_embeddings.npz, cyanobact.original.long.prots.length_filtered_embeddings.npy
2026-01-19 02:02:35,103 - INFO - Embedding shape: (4282, 512)
[1/1] Completed: cyanobact.original.long.prots.length_filtered.fasta
Progress: 1 success, 0 failed

FINAL RESULTS:
Successfully processed: 1 files
Failed to process: 0 files
Output directory: embeddings
2026-01-19 02:02:35,310 - INFO - Successfully processed: 1 files
2026-01-19 02:02:35,311 - INFO - Failed to process: 0 files
2026-01-19 02:02:35,313 - INFO - Output directory: embeddings


In [26]:
# Score every embedding of the original proteins with the trained model and
# collect the per-label probabilities in `predictions`.
embeddings = np.load('embeddings/cyanobact.original.long.prots.length_filtered_embeddings.npy')

dataset = TensorDataset(torch.tensor(embeddings, dtype=torch.float))
# shuffle=False keeps the rows of `predictions` in the same order as `embeddings`.
loader = DataLoader(dataset, batch_size=64, shuffle=False)

# Use the GPU when one is available, otherwise fall back to the CPU. The same
# `device` is used for both the model and the batches so they cannot diverge.
device = "cuda" if torch.cuda.is_available() else "cpu"

model.to(device)
model.eval()  # disable dropout for inference

all_preds = []

with torch.no_grad():
    for (x_batch,) in loader:
        x_batch = x_batch.to(device)
        outputs = model(x_batch)
        # The model emits raw outputs; sigmoid maps each one to a probability.
        preds = torch.sigmoid(outputs).cpu().numpy()
        all_preds.append(preds)

# Stack the per-batch arrays into a single (n_sequences, n_labels) array.
predictions = np.vstack(all_preds)

array([[1.11863621e-07, 6.50834775e-10, 9.38488881e-07, ...,
        7.05698028e-07, 5.74189141e-09, 7.97522627e-03],
       [2.94914644e-05, 3.31585807e-06, 8.99189581e-06, ...,
        1.97517675e-05, 4.03580589e-06, 4.09856647e-01],
       [5.90996052e-09, 1.85903082e-09, 7.14630289e-07, ...,
        2.16212379e-06, 8.36399408e-07, 7.29877320e-06],
       ...,
       [2.17063589e-05, 4.41651810e-06, 2.02611591e-05, ...,
        2.24522315e-04, 2.99220659e-07, 3.67216882e-04],
       [2.35374114e-06, 6.64090976e-07, 4.42947248e-06, ...,
        1.02044782e-04, 2.75696550e-07, 1.83125274e-04],
       [1.18651449e-04, 8.91041127e-05, 2.81053162e-05, ...,
        2.03931704e-05, 2.93227276e-06, 1.08257365e-04]],
      shape=(4282, 1098), dtype=float32)

In [46]:
import pickle

# Load the fitted label binarizer so prediction columns can be mapped to names.
# NOTE(review): pickle.load executes arbitrary code from the file - only use
# it on a trusted data_mlb.pkl.
with open("./data_mlb.pkl", "rb") as mlb_file:
    data_mlb = pickle.load(mlb_file)

# Turn the probability matrix into per-row label tuples using a fixed cut-off.
# NOTE(review): 0.0001 is a very permissive threshold and `labels` is not used
# by the other cells visible here - confirm the intended value.
above_threshold = predictions > 0.0001
labels = data_mlb.inverse_transform(above_threshold)

# Label names in column order; displayed as the cell's output.
all_go = data_mlb.classes_
all_go

array(['GO:0000030', 'GO:0000034', 'GO:0000035', ..., 'GO:1904680',
       'GO:1990107', 'GO:1990837'], shape=(1098,), dtype=object)

In [49]:
# Build a labelled table of probabilities: one row per sequence, one column
# per label.
# The .npz archive is opened in a `with` block so its file handle is closed
# once the key names have been read (a bare np.load(...).files leaves it open).
# NOTE(review): this assumes the archive keys are the sequence IDs and are in
# the same order as the rows of the .npy matrix - confirm against
# generate_embeddings.py.
with np.load("embeddings/cyanobact.original.long.prots.length_filtered_embeddings.npz") as npz_archive:
    headers = list(npz_archive.files)

df = pd.DataFrame(
    predictions,
    index=headers,
    columns=all_go
)
df

Unnamed: 0,GO:0000030,GO:0000034,GO:0000035,GO:0000036,GO:0000049,GO:0000104,GO:0000107,GO:0000155,GO:0000156,GO:0000166,...,GO:1901265,GO:1901363,GO:1901505,GO:1901681,GO:1901682,GO:1903425,GO:1904047,GO:1904680,GO:1990107,GO:1990837
CK_Pro_HNLC2_01521,1.118636e-07,6.508348e-10,9.384889e-07,3.007184e-06,3.721399e-05,1.112826e-08,9.780335e-09,1.593398e-06,0.000196,0.000176,...,0.000233,0.000227,0.000004,5.757060e-07,7.388383e-07,4.642738e-09,1.363363e-10,7.056980e-07,5.741891e-09,0.007975
CK_Syn_WH8016_01352,2.949146e-05,3.315858e-06,8.991896e-06,1.253963e-05,5.908323e-06,4.201046e-05,6.549231e-06,3.510659e-03,0.465769,0.005394,...,0.006673,0.026787,0.000040,9.006405e-05,4.284191e-04,3.488497e-06,1.137213e-05,1.975177e-05,4.035806e-06,0.409857
CK_Pro_MIT9314_00172,5.909961e-09,1.859031e-09,7.146303e-07,7.838213e-07,2.125682e-08,3.377305e-07,2.029289e-08,1.542352e-07,0.000002,1.000000,...,1.000000,1.000000,0.007943,6.631090e-06,7.285576e-02,9.785694e-08,7.222477e-06,2.162124e-06,8.363994e-07,0.000007
CK_Syn_BMK-MC-1_01498,1.493566e-08,3.657227e-11,2.247025e-06,3.485092e-06,5.160576e-07,4.315622e-06,5.483355e-08,6.752621e-07,0.000011,0.999995,...,0.999996,0.999998,0.000116,1.327090e-05,4.241196e-04,1.018060e-06,1.718478e-07,2.005316e-06,2.298000e-07,0.000055
CK_Syn_A18-40_02421,7.164195e-07,1.627143e-07,1.409644e-07,5.985796e-07,3.720277e-09,4.392184e-03,2.899120e-07,2.588477e-06,0.000002,0.000235,...,0.000230,0.000161,0.000010,1.088075e-05,3.781345e-05,1.112218e-04,9.605615e-08,5.904684e-06,5.806916e-09,0.000010
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
CK_Syn_WH8016_01676,9.724050e-05,1.193213e-05,1.179311e-06,5.688241e-07,3.415846e-09,1.591228e-05,2.383811e-05,5.716978e-06,0.000033,0.000239,...,0.000242,0.005447,0.000005,2.829412e-04,4.248186e-07,3.353662e-06,5.081999e-04,7.038080e-05,3.111782e-07,0.000033
CK_Cya_PCC7001_01923,3.451052e-05,3.004052e-05,7.284623e-06,1.702200e-06,6.080307e-05,5.663235e-06,6.078383e-07,3.409475e-08,0.000013,0.139590,...,0.137127,0.138064,0.000021,6.260543e-06,2.394915e-05,8.699118e-07,2.464591e-05,5.052889e-05,3.077695e-05,0.000005
CK_Pro_MIT0601_00730,2.170636e-05,4.416518e-06,2.026116e-05,2.314029e-05,1.950867e-07,1.049888e-06,9.984220e-05,2.263685e-06,0.000564,0.000108,...,0.000114,0.001645,0.000063,8.822869e-03,1.429854e-03,3.961058e-06,6.849683e-05,2.245223e-04,2.992207e-07,0.000367
CK_Syn_CC9311_00497,2.353741e-06,6.640910e-07,4.429472e-06,3.467045e-06,9.077188e-08,6.104728e-07,6.223776e-05,3.482233e-07,0.000055,0.000002,...,0.000002,0.000059,0.000135,9.447208e-04,8.299086e-04,1.386093e-06,2.161324e-05,1.020448e-04,2.756966e-07,0.000183


In [50]:
# Persist the probability table as a tab-separated file (index included).
original_predictions_path = "cyanobact.original.predictions.tsv"
df.to_csv(original_predictions_path, sep="\t")

In [51]:
from Bio import SeqIO

input_fasta = "/shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.long.prots.shuffled.fasta"
records = list(SeqIO.parse(input_fasta, "fasta"))
filtered = [r for r in records if len(r.seq) <= 1200]

SeqIO.write(filtered, "/shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.long.prots.shuffled.length_filtered.fasta", "fasta")

%run /shared/projects/deepmar/PlasmoFP_public/src/generate_embeddings.py --input /shared/projects/deepmar/data/cyanobacteriota/marine_unlabeled_data/cyanobact.long.prots.shuffled.length_filtered.fasta --output ./embeddings/ --output_format npz --tm_vec_model /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt  --device cuda  --tm_vec_config /shared/projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model_params.json

embeddings = np.load('embeddings/cyanobact.long.prots.shuffled.length_filtered_embeddings.npy')
model.eval()

dataset = TensorDataset(torch.tensor(embeddings, dtype=torch.float))
loader = DataLoader(dataset, batch_size=64, shuffle=False)
device='cuda'

model.eval()
model.to("cuda")

all_preds = []

with torch.no_grad():
    for (x_batch,) in loader:
        x_batch = x_batch.to(device)
        outputs = model(x_batch)
        preds = torch.sigmoid(outputs).cpu().numpy()  # Apply sigmoid for probability outputs
        all_preds.append(preds)

# Stack predictions and targets into numpy arrays
predictions = np.vstack(all_preds)

headers = np.load("embeddings/cyanobact.long.prots.shuffled.length_filtered_embeddings.npz").files
df = pd.DataFrame(
    predictions,
    index=headers,
    columns=all_go
)
df

df.to_csv("cyanobact.shuffled.predictions.tsv", sep="\t")

Embedding generation started!
Output directory: embeddings
Log file: embeddings/embedding_generation.log
Verbose mode: OFF
Using device: cuda
2026-01-19 02:22:48,518 - INFO - Using device: cuda
Found 1 FASTA file(s) to process:
   1. cyanobact.long.prots.shuffled.length_filtered.fasta

2026-01-19 02:22:48,518 - INFO - Found 1 FASTA file(s) to process
Loading ProtT5 tokenizer and model...
2026-01-19 02:22:48,519 - INFO - Loading ProtT5 tokenizer and model...
Loading TM-Vec model...
2026-01-19 02:22:53,717 - INFO - Loading TM-Vec model...


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.6.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../../../../../../../projects/deepmar/data/tmvec_model_weights/tm_vec_cath_model.ckpt`


Models loaded successfully on device: cuda
2026-01-19 02:22:54,042 - INFO - Models loaded successfully on device: cuda
Starting processing...
[1/1] Processing: cyanobact.long.prots.shuffled.length_filtered.fasta
Processing file: cyanobact.long.prots.shuffled.length_filtered.fasta
2026-01-19 02:22:54,043 - INFO - Processing file: cyanobact.long.prots.shuffled.length_filtered.fasta
Found 4282 sequences in cyanobact.long.prots.shuffled.length_filtered.fasta
2026-01-19 02:22:54,056 - INFO - Found 4282 sequences in cyanobact.long.prots.shuffled.length_filtered.fasta
Generating embeddings... (this may take a while)
2026-01-19 02:22:54,056 - INFO - Generating embeddings...


100%|██████████| 4282/4282 [14:47<00:00,  4.82it/s]


Successfully processed cyanobact.long.prots.shuffled.length_filtered.fasta: 4282 sequences
Output files: cyanobact.long.prots.shuffled.length_filtered_embeddings.npz, cyanobact.long.prots.shuffled.length_filtered_embeddings.npy
Embedding shape: (4282, 512)
2026-01-19 02:37:42,355 - INFO - Successfully processed cyanobact.long.prots.shuffled.length_filtered.fasta: 4282 sequences → cyanobact.long.prots.shuffled.length_filtered_embeddings.npz, cyanobact.long.prots.shuffled.length_filtered_embeddings.npy
2026-01-19 02:37:42,355 - INFO - Embedding shape: (4282, 512)
[1/1] Completed: cyanobact.long.prots.shuffled.length_filtered.fasta
Progress: 1 success, 0 failed

FINAL RESULTS:
Successfully processed: 1 files
Failed to process: 0 files
Output directory: embeddings
2026-01-19 02:37:42,602 - INFO - Successfully processed: 1 files
2026-01-19 02:37:42,603 - INFO - Failed to process: 0 files
2026-01-19 02:37:42,605 - INFO - Output directory: embeddings
