In [None]:
from tarp.model.backbone.untrained.lstm import LstmEncoder
from tarp.model.backbone.untrained.hyena import HyenaEncoder
from tarp.model.finetuning.classification import ClassificationModel

from tarp.services.datasets.classification.multilabel import MultiLabelClassificationDataset
from tarp.services.tokenizers.pretrained.dnabert import Dnabert2Tokenizer
from tarp.services.datasource.sequence import TabularSequenceSource


from tarp.services.preprocessing.augumentation import (
    CombinationTechnique,
    RandomMutation,
    InsertionDeletion,
    ReverseComplement,
)

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import polars as pl
import numpy as np

from sklearn.neighbors import NearestNeighbors

from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = pl.read_parquet("temp/data/preprocessed/card_amr.parquet")
# Get every column except 'sequence' as label columns
label_columns = [col for col in df.collect_schema().names() if col != "sequence"]

In [None]:
dataset = MultiLabelClassificationDataset(
    TabularSequenceSource(source=Path("temp/data/preprocessed/card_amr.parquet")),
    Dnabert2Tokenizer(),
    sequence_column="sequence",
    label_columns=label_columns,
    maximum_sequence_length=512,
    augumentation=CombinationTechnique(
        [
            RandomMutation(),
            InsertionDeletion(),
            ReverseComplement(0.5),
        ]
    ),
)
from services.datasets.metric.triplet import MultilabelOfflineTripletDataset

metric_dataset = MultilabelOfflineTripletDataset(base_dataset=dataset)

In [4]:
encoder = HyenaEncoder(
        vocabulary_size=dataset.tokenizer.vocab_size,
        embedding_dimension=128,
        hidden_dimension=256,
        padding_id=dataset.tokenizer.pad_token_id,
        number_of_layers=2,
        dropout=0.2,
    )

classification_model = ClassificationModel(
    encoder=encoder,
    number_of_classes=len(label_columns),
)

classification_model.load_state_dict(
    torch.load(r"temp\models\HyenaEncoder_20251010_132058.pt")
)

# Get the encoder part of the model
encoder: HyenaEncoder = classification_model.encoder

In [5]:
# Apply the model to the dataset to get the embeddings
from tqdm.auto import tqdm

# Preallocate numpy array for all embeddings
num_samples = len(dataset)
batch_size = 32
embedding_dim = encoder.encoding_size  # Output dimension of encoder.encode
embeddings = np.empty((num_samples, embedding_dim), dtype=np.float32)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(DEVICE)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
start_idx = 0
encoder.eval()
with torch.no_grad():
    for batch in tqdm(dataloader):
        input_ids = batch["sequence"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        batch_embeddings = encoder.encode(input_ids, attention_mask)
        batch_size_actual = batch_embeddings.shape[0]
        embeddings[start_idx:start_idx + batch_size_actual] = batch_embeddings.cpu().numpy()
        start_idx += batch_size_actual

100%|██████████| 200/200 [00:05<00:00, 34.03it/s]


In [6]:
# Print the shape of the embeddings
print("Embeddings shape:", embeddings.shape)

Embeddings shape: (6392, 128)


In [7]:
# Split the dataset into train and test sets
train_size = int(0.8 * len(dataset))
train_embeddings = embeddings[:train_size]
test_embeddings = embeddings[train_size:]

print("Train embeddings shape:", train_embeddings.shape)
print("Test embeddings shape:", test_embeddings.shape)


Train embeddings shape: (5113, 128)
Test embeddings shape: (1279, 128)


In [8]:
# KNN search
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=5, metric='cosine')

knn.fit(train_embeddings, [
    dataset[i]['labels'].numpy() for i in range(train_size)
])

# Classify test set
predictions = knn.predict(test_embeddings)



In [9]:
# Classification report
from sklearn.metrics import classification_report

print(classification_report(
    [dataset[i + train_size]['labels'].numpy() for i in range(len(test_embeddings))],
    predictions,
    zero_division=0,
    target_names=label_columns
))

                                           precision    recall  f1-score   support

      disinfecting agents and antiseptics       0.50      0.06      0.11        16
                            glycylcycline       0.00      0.00      0.00         6
                     rifamycin antibiotic       0.00      0.00      0.00        13
                     macrolide antibiotic       0.40      0.24      0.30        42
                 streptogramin antibiotic       0.80      0.71      0.75        17
                      pyrazine antibiotic       0.00      0.00      0.00         4
                  tetracycline antibiotic       0.38      0.11      0.18        44
             bicyclomycin-like antibiotic       0.00      0.00      0.00         1
                isoniazid-like antibiotic       0.00      0.00      0.00         5
                nitroimidazole antibiotic       0.00      0.00      0.00         4
                  orthosomycin antibiotic       0.00      0.00      0.00         0
   

In [10]:
# %%
import plotly.express as px

from sklearn.manifold import TSNE

# Run t-SNE on the embeddings
tsne = TSNE(n_components=2, random_state=102, perplexity=30)
train_embeddings_2d = tsne.fit_transform(train_embeddings)


In [11]:
from torch.utils.data import DataLoader

# Collect labels for the training split
labels = []
dataloader_labels = DataLoader(dataset, batch_size=32, shuffle=False)

for i, batch in enumerate(dataloader_labels):
    if i * 32 >= train_size:  # stop after train split
        break
    labels.append(batch["labels"].numpy())

train_labels = np.vstack(labels)[:train_size]  # shape (N, num_labels)
print("Train labels shape:", train_labels.shape)

# Convert multilabel → single label for visualization
train_labels_simple = train_labels.argmax(axis=1)

Train labels shape: (5113, 46)


In [12]:
fig = px.scatter(
    x=train_embeddings_2d[:, 0],
    y=train_embeddings_2d[:, 1],
    color=[label_columns[i] for i in train_labels_simple],
    title="t-SNE visualization of Hyena embeddings",
    labels={"x": "t-SNE dim 1", "y": "t-SNE dim 2", "color": "Gene family"},
    opacity=0.7,
)
# Square figure
fig.update_layout(width=1024, height=768)
fig.show()



In [13]:
# ROC AUC curve for each class
from sklearn.metrics import roc_auc_score
for i, label in enumerate(label_columns):
    y_true = [dataset[j + train_size]['labels'].numpy()[i] for j in range(len(test_embeddings))]
    y_scores = [predictions[j][i] for j in range(len(test_embeddings))]
    auc = roc_auc_score(y_true, y_scores)
    print(f"ROC AUC for {label}: {auc:.4f}")

ROC AUC for disinfecting agents and antiseptics: 0.5309
ROC AUC for glycylcycline: 0.4996
ROC AUC for rifamycin antibiotic: 0.4996
ROC AUC for macrolide antibiotic: 0.6130
ROC AUC for streptogramin antibiotic: 0.8518
ROC AUC for pyrazine antibiotic: 0.5000
ROC AUC for tetracycline antibiotic: 0.5536
ROC AUC for bicyclomycin-like antibiotic: 0.5000
ROC AUC for isoniazid-like antibiotic: 0.5000
ROC AUC for nitroimidazole antibiotic: 0.5000



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for orthosomycin antibiotic: nan



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for nitrofuran antibiotic: nan
ROC AUC for carbapenem: 0.9583



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for pactamycin-like antibiotic: nan



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for moenomycin antibiotic: nan
ROC AUC for cycloserine-like antibiotic: 0.5000
ROC AUC for cephalosporin: 0.9604
ROC AUC for diaminopyrimidine antibiotic: 0.6169
ROC AUC for fluoroquinolone antibiotic: 0.7233
ROC AUC for antibiotic without defined classification: 0.5000
ROC AUC for phosphonic acid antibiotic: 0.5905
ROC AUC for pleuromutilin antibiotic: 0.7845
ROC AUC for elfamycin antibiotic: 0.5000
ROC AUC for nucleoside antibiotic: 0.5000
ROC AUC for peptide antibiotic: 0.7624
ROC AUC for phenicol antibiotic: 0.5357
ROC AUC for streptogramin A antibiotic: 0.6496
ROC AUC for aminocoumarin antibiotic: 0.4996
ROC AUC for streptogramin B antibiotic: 0.6364
ROC AUC for sulfonamide antibiotic: 0.5000
ROC AUC for fusidane antibiotic: 0.5000



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for zoliflodacin-like antibiotic: nan



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for sulfone antibiotic: nan
ROC AUC for thiosemicarbazone antibiotic: 0.5000
ROC AUC for glycopeptide antibiotic: 0.5000
ROC AUC for oxazolidinone antibiotic: 0.5000
ROC AUC for aminoglycoside antibiotic: 0.6085
ROC AUC for thioamide antibiotic: 0.5000



Only one class is present in y_true. ROC AUC score is not defined in that case.



ROC AUC for salicylic acid antibiotic: nan
ROC AUC for mupirocin-like antibiotic: 0.5000
ROC AUC for diarylquinoline antibiotic: 0.5000
ROC AUC for penicillin beta-lactam: 0.9547
ROC AUC for antibacterial free fatty acids: 0.5000
ROC AUC for lincosamide antibiotic: 0.7357
ROC AUC for monobactam: 0.9862
ROC AUC for polyamine antibiotic: 0.5000
